Pytorch的Dataloader使用详解

发布于:2025-05-15 ⋅ 阅读:(15) ⋅ 点赞:(0)

PyTorch 的 DataLoader 是数据加载的核心组件,它能高效地批量加载数据并进行预处理。

Pytorch DataLoader基础概念

DataLoader基础概念
DataLoader是PyTorch基础概念
DataLoader是PyTorch中用于加载数据的工具,它可以:批量加载数据(batch loading)打乱数据(shuffling)并行加载数据(多线程)
自定义数据加载方式Dataloader的基本使用from torch.utils.data import Dataset, DataLoader

自定义数据集类

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
    
    def __getitem__(self, index):
        return self.data[index], self.labels[index]
    
    def __len__(self):
        return len(self.data)

创建数据集实例

dataset = MyDataset(data, labels)

创建DataLoader

dataloader = DataLoader(
    dataset=dataset,      # 数据集
    batch_size=32,        # 批次大小
    shuffle=True,         # 是否打乱数据
    num_workers=4,        # 多进程加载数据的线程数
    drop_last=False       # 当样本数不能被batch_size整除时,是否丢弃最后一个不完整的batch
)
# 使用DataLoader迭代数据
for batch_data, batch_labels in dataloader:
    # 训练或推理代码
    pass

DataLoader重要参数详解

  1. dataset: 要加载的数据集,必须是Dataset类的实例 batch_size: 每个批次的样本数
  2. shuffle:是否在每个epoch重新打乱数据
  3. sampler:自定义从数据集中抽取样本的策略,如果指定了sampler,则shuffle必须为False
  4. num_workers:使用多少个子进程加载数据,0表示在主进程中加载。
  5. collate_fn:将一批数据整合成一个批次的函数,特别使用于处理不同长度的序列数据
  6. Pin_memory:如果为True,数据加载器会将张量复制到CUDA固定内存中,加速CPU到GPU的数据传输
  7. drop_last: 如果数据集大小不能被batch_size整除,是否丢弃最后一个不完整的批次。
  8. timeout:收集一个批次的超时值
  9. worker_init_fn:每个worker初始化时被调用的函数
  10. weight_sampler:参数决定是都使用加权采样器来平衡类别分布
if infinite_data_loader:
    data_loader = InfiniteDataLoader(
        dataset=data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        sampler=sampler,
        **kwargs
    )
else:
    data_loader = DataLoader(
        dataset=data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        sampler=sampler,
        **kwargs
    )

n_class = len(data.classes)
return data_loader, n_class
这段代码决定了如何创建数据加载器,根据infinite_data_loader参数选择不同的加载器类型:
if infinite_data_loader:
    data_loader = InfiniteDataLoader(
        dataset=data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        sampler=sampler,
        **kwargs
    )
else:
    data_loader = DataLoader(
        dataset=data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        sampler=sampler,
        **kwargs
    )

n_class = len(data.classes)
return data_loader, n_class

代码解析

这段代码基于infinite_data_loader参数创建不同类型的数据加载器:
当infinite_data_loader为True时:
创建InfiniteDataLoader实例
自定义的无限循环数据加载器,会持续提供数据而不会在一个epoch结束时停止
当infinite_data_loader为False时:
创建标准的PyTorch DataLoader实例
这是普通的数据加载器,一个epoch结束后会停止

共同参数:

dataset=data:要加载的数据集
batch_size=batch_size:每批数据的大小
shuffle=shuffle:是否打乱数据(之前代码中已设置)
num_workers=num_workers:用于并行加载数据的线程数
sampler=sampler:用于采样的策略(之前代码中已设置,可能是加权采样器)
**kwargs:其他可能的参数,如pin_memory、drop_last等

返回值:

data_loader:创建好的数据加载器
n_class = len(data.classes):数据集中的类别数量
InfiniteDataLoader的作用
在您的代码中定义了两种InfiniteDataLoader实现:一种作为DataLoader的子类,另一种是完全自定义的类。它们的共同目的是:
持续提供数据:当一个epoch结束后,自动重新开始,不会引发StopIteration异常
支持长时间训练:在需要长时间训练的场景中特别有用,如半监督学习或者领域适应
避免手动重置:不需要在每个epoch结束后手动重置数据加载器

使用场景

无限数据加载器特别适用于:
持续训练:模型需要无限期地训练,如自监督学习或强化学习
不均匀更新:源域和目标域数据需要不同频率的更新
流式训练:数据以流的形式到达,不需要明确的epoch边界
基于迭代而非epoch的训练:训练基于迭代次数而非数据epoch
最后的返回值n_class提供了数据集的类别数量,这对模型构建和评估都很重要,比如设置分类层的输出维度或计算平均类别准确率。
高级用法

1.自定义collate_fn处理变长序列

def collate_fn(batch):
    # 排序批次数据,按序列长度降序
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    
    # 分离数据和标签
    sequences, labels = zip(*batch)
    
    # 计算每个序列的长度
    lengths = [len(seq) for seq in sequences]
    
    # 填充序列到相同长度
    padded_seqs = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    
    return padded_seqs, torch.tensor(labels), lengths

使用自定义的collate_fn

dataloader = DataLoader(
    dataset=text_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn
)

2.使用Sampler进行不均衡数据采样
from torch.utils.data import WeightedRandomSampler

假设我们有类别不平衡问题,计算采样权重

class_count = [100, 1000, 500]  # 每个类别的样本数量
weights = 1.0 / torch.tensor(class_count, dtype=torch.float)
sample_weights = weights[target_list]  # target_list是每个样本的类别索引

创建WeightedRandomSampler

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

使用sampler

dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    sampler=sampler,  # 使用sampler时,shuffle必须为False
    num_workers=4
)

网站公告

今日签到

点亮在社区的每一天
去签到