PyTorch 的 DataLoader 是数据加载的核心组件,它能高效地批量加载数据并进行预处理。
Pytorch DataLoader基础概念
DataLoader基础概念
DataLoader是PyTorch基础概念
DataLoader是PyTorch中用于加载数据的工具,它可以:批量加载数据(batch loading)打乱数据(shuffling)并行加载数据(多线程)
自定义数据加载方式Dataloader的基本使用from torch.utils.data import Dataset, DataLoader
自定义数据集类
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __getitem__(self, index):
return self.data[index], self.labels[index]
def __len__(self):
return len(self.data)
创建数据集实例
dataset = MyDataset(data, labels)
创建DataLoader
dataloader = DataLoader(
dataset=dataset, # 数据集
batch_size=32, # 批次大小
shuffle=True, # 是否打乱数据
num_workers=4, # 多进程加载数据的线程数
drop_last=False # 当样本数不能被batch_size整除时,是否丢弃最后一个不完整的batch
)
# 使用DataLoader迭代数据
for batch_data, batch_labels in dataloader:
# 训练或推理代码
pass
DataLoader重要参数详解
- dataset: 要加载的数据集,必须是Dataset类的实例 batch_size: 每个批次的样本数
- shuffle:是否在每个epoch重新打乱数据
- sampler:自定义从数据集中抽取样本的策略,如果指定了sampler,则shuffle必须为False
- num_workers:使用多少个子进程加载数据,0表示在主进程中加载。
- collate_fn:将一批数据整合成一个批次的函数,特别使用于处理不同长度的序列数据
- Pin_memory:如果为True,数据加载器会将张量复制到CUDA固定内存中,加速CPU到GPU的数据传输
- drop_last: 如果数据集大小不能被batch_size整除,是否丢弃最后一个不完整的批次。
- timeout:收集一个批次的超时值
- worker_init_fn:每个worker初始化时被调用的函数
- weight_sampler:参数决定是都使用加权采样器来平衡类别分布
if infinite_data_loader:
data_loader = InfiniteDataLoader(
dataset=data,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
sampler=sampler,
**kwargs
)
else:
data_loader = DataLoader(
dataset=data,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
sampler=sampler,
**kwargs
)
n_class = len(data.classes)
return data_loader, n_class
这段代码决定了如何创建数据加载器,根据infinite_data_loader参数选择不同的加载器类型:
if infinite_data_loader:
data_loader = InfiniteDataLoader(
dataset=data,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
sampler=sampler,
**kwargs
)
else:
data_loader = DataLoader(
dataset=data,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
sampler=sampler,
**kwargs
)
n_class = len(data.classes)
return data_loader, n_class
代码解析
这段代码基于infinite_data_loader参数创建不同类型的数据加载器:
当infinite_data_loader为True时:
创建InfiniteDataLoader实例
自定义的无限循环数据加载器,会持续提供数据而不会在一个epoch结束时停止
当infinite_data_loader为False时:
创建标准的PyTorch DataLoader实例
这是普通的数据加载器,一个epoch结束后会停止
共同参数:
dataset=data:要加载的数据集
batch_size=batch_size:每批数据的大小
shuffle=shuffle:是否打乱数据(之前代码中已设置)
num_workers=num_workers:用于并行加载数据的线程数
sampler=sampler:用于采样的策略(之前代码中已设置,可能是加权采样器)
**kwargs:其他可能的参数,如pin_memory、drop_last等
返回值:
data_loader:创建好的数据加载器
n_class = len(data.classes):数据集中的类别数量
InfiniteDataLoader的作用
在您的代码中定义了两种InfiniteDataLoader实现:一种作为DataLoader的子类,另一种是完全自定义的类。它们的共同目的是:
持续提供数据:当一个epoch结束后,自动重新开始,不会引发StopIteration异常
支持长时间训练:在需要长时间训练的场景中特别有用,如半监督学习或者领域适应
避免手动重置:不需要在每个epoch结束后手动重置数据加载器
使用场景
无限数据加载器特别适用于:
持续训练:模型需要无限期地训练,如自监督学习或强化学习
不均匀更新:源域和目标域数据需要不同频率的更新
流式训练:数据以流的形式到达,不需要明确的epoch边界
基于迭代而非epoch的训练:训练基于迭代次数而非数据epoch
最后的返回值n_class提供了数据集的类别数量,这对模型构建和评估都很重要,比如设置分类层的输出维度或计算平均类别准确率。
高级用法
1.自定义collate_fn处理变长序列
def collate_fn(batch):
# 排序批次数据,按序列长度降序
batch.sort(key=lambda x: len(x[0]), reverse=True)
# 分离数据和标签
sequences, labels = zip(*batch)
# 计算每个序列的长度
lengths = [len(seq) for seq in sequences]
# 填充序列到相同长度
padded_seqs = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
return padded_seqs, torch.tensor(labels), lengths
使用自定义的collate_fn
dataloader = DataLoader(
dataset=text_dataset,
batch_size=16,
shuffle=True,
collate_fn=collate_fn
)
2.使用Sampler进行不均衡数据采样
from torch.utils.data import WeightedRandomSampler
假设我们有类别不平衡问题,计算采样权重
class_count = [100, 1000, 500] # 每个类别的样本数量
weights = 1.0 / torch.tensor(class_count, dtype=torch.float)
sample_weights = weights[target_list] # target_list是每个样本的类别索引
创建WeightedRandomSampler
sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True
)
使用sampler
dataloader = DataLoader(
dataset=dataset,
batch_size=32,
sampler=sampler, # 使用sampler时,shuffle必须为False
num_workers=4
)