一、PyTorch数据加载核心组件
在PyTorch中,数据准备主要涉及两个核心类:Dataset和DataLoader。它们共同构成了PyTorch灵活高效的数据管道系统。
- Dataset类:
- 作为数据集的抽象基类,需要实现三个关键方法:
- len(): 返回数据集大小
- getitem(): 获取单个数据样本
- (可选) init(): 初始化逻辑
- 常见实现方式:
- 继承torch.utils.data.Dataset
- 使用TensorDataset处理张量数据
- 使用ImageFolder处理图像文件夹
- 示例场景:
class CustomDataset(Dataset): def __init__(self, data, labels): self.data = data self.labels = labels def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx], self.labels[idx]
2.DataLoader类:
- 主要功能:
- 批量加载数据
- 数据打乱(shuffle=True)
- 多进程数据加载
- 内存管理
- 关键参数:
- batch_size: 每批数据量
- shuffle: 是否随机打乱
- num_workers: 子进程数
- pin_memory: 加速GPU传输
- 典型使用方式:
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) for batch in loader: # 训练逻辑
- 组合优势:
- 内存效率:仅加载当前需要的批次
- 灵活性:支持自定义数据转换
- 性能:多进程并行加载
- 标准化:统一数据访问接口
- 高级特性:
- Sampler
- 控制数据采样顺序
- 自定义collate_fn处理复杂批次结构
- 使用IterableDataset处理流式数据
这套数据管道系统使得PyTorch能够高效处理从GB到TB级别的各种数据集,是深度学习训练流程的重要基础组件。
1.1 Dataset类详解
Dataset
是一个抽象类,是所有自定义数据集应该继承的基类。它定义了数据集必须实现的方法:
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
"""
初始化数据集
:param data: 样本数据(NumPy数组或PyTorch张量)
:param labels: 样本标签
"""
self.data = data
self.labels = labels
def __len__(self):
"""返回数据集的大小"""
return len(self.data)
def __getitem__(self, index):
"""
支持整数索引,返回对应的样本
:param index: 样本索引
:return: (样本数据, 标签)
"""
sample = self.data[index]
label = self.labels[index]
return sample, label
关键方法说明:
__init__
: 初始化方法,通常在这里加载数据或定义数据路径__len__
: 返回数据集大小,供DataLoader确定迭代次数__getitem__
: 根据索引返回样本,支持数据增强和转换
1.2 TensorDataset便捷类
当数据已经是张量形式时,可以使用TensorDataset
简化代码:
from torch.utils.data import TensorDataset
import torch
# 创建特征和标签张量
features = torch.randn(100, 5) # 100个样本,每个5个特征
labels = torch.randint(0, 2, (100,)) # 100个二进制标签
# 创建数据集
dataset = TensorDataset(features, labels)
# 查看第一个样本
print(dataset[0]) # 输出: (tensor([...]), tensor(0))
TensorDataset
源码分析
class TensorDataset(Dataset):
def __init__(self, *tensors):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
二、DataLoader:高效数据加载引擎
DataLoader
是一个迭代器,负责从Dataset
中批量加载数据,并提供多种实用功能。
2.1 基本使用方法
from torch.utils.data import DataLoader
# 创建DataLoader
dataloader = DataLoader(
dataset, # 数据集对象
batch_size=32, # 批量大小
shuffle=True, # 是否在每个epoch打乱数据
num_workers=4, # 使用4个子进程加载数据
drop_last=False # 是否丢弃最后不完整的batch
)
# 遍历DataLoader
for batch_idx, (data, labels) in enumerate(dataloader):
print(f"Batch {batch_idx}:")
print("Data shape:", data.shape) # [batch_size, ...]
print("Labels shape:", labels.shape) # [batch_size]
2.2 关键参数详解
参数 | 类型 | 说明 | 默认值 |
---|---|---|---|
dataset |
Dataset | 要加载的数据集对象 | - |
batch_size |
int | 每个batch的样本数 | 1 |
shuffle |
bool | 是否在每个epoch开始时打乱数据 | False |
num_workers |
int | 用于数据加载的子进程数 | 0 |
drop_last |
bool | 是否丢弃最后一个不完整的batch | False |
pin_memory |
bool | 是否将数据复制到CUDA固定内存 | False |
collate_fn |
callable | 合并样本列表形成batch的函数 | None |
2.3 多进程加载原理
当num_workers > 0
时,DataLoader会使用多进程加速数据加载:
主进程创建
num_workers
个子进程每个子进程独立加载数据
通过共享内存或队列将数据传输给主进程
主进程将数据组装成batch
注意事项:
在Windows系统下需要将主要代码放在
if __name__ == '__main__':
中子进程会复制父进程的所有资源,可能导致内存问题
子进程中的随机状态可能与主进程不同
三、实战案例:不同类型数据加载
3.1 CSV数据加载
import pandas as pd
from torch.utils.data import Dataset
class CsvDataset(Dataset):
def __init__(self, file_path):
"""
加载CSV文件创建数据集
:param file_path: CSV文件路径
"""
df = pd.read_csv(file_path)
# 假设最后一列是标签,其余是特征
self.features = df.iloc[:, :-1].values
self.labels = df.iloc[:, -1].values
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
features = torch.FloatTensor(self.features[idx])
label = torch.LongTensor([self.labels[idx]])[0]
return features, label
# 使用示例
dataset = CsvDataset('data.csv')
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
3.2 图像数据加载
自定义图像数据集
import os
import cv2
from torchvision import transforms
class ImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
"""
:param root_dir: 图片根目录,子目录名为类别名
:param transform: 图像变换组合
"""
self.root_dir = root_dir
self.transform = transform
self.classes = sorted(os.listdir(root_dir))
self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
self.samples = self._make_dataset()
def _make_dataset(self):
samples = []
for cls in self.classes:
cls_dir = os.path.join(self.root_dir, cls)
for img_name in os.listdir(cls_dir):
img_path = os.path.join(cls_dir, img_name)
samples.append((img_path, self.class_to_idx[cls]))
return samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
# 使用OpenCV读取图像(BGR格式)
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 转换为RGB
if self.transform:
img = self.transform(img)
return img, torch.tensor(label)
# 定义图像变换
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 使用示例
dataset = ImageDataset('images/', transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
使用torchvision的ImageFolder
对于标准图像分类数据集,可以使用ImageFolder
简化流程:
from torchvision.datasets import ImageFolder
from torchvision import transforms
# 定义变换
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 加载数据集
dataset = ImageFolder(root='path/to/data', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# 查看类别映射
print(dataset.class_to_idx) # 输出: {'cat': 0, 'dog': 1}
3.3 官方数据集加载
PyTorch提供了多种常用数据集的便捷加载方式:
from torchvision import datasets, transforms
# MNIST手写数字数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_set = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform
)
test_set = datasets.MNIST(
root='./data',
train=False,
download=True,
transform=transform
)
# CIFAR-10数据集
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
train_set = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
四、高级技巧与最佳实践
4.1 数据增强策略
from torchvision import transforms
# 训练集变换(包含数据增强)
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 测试集变换(仅标准化)
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
4.2 自定义collate_fn
当默认的batch组装方式不满足需求时,可以自定义collate_fn
:
def custom_collate(batch):
# batch是包含多个__getitem__返回值的列表
# 例如对于图像分割任务,可能有图像和对应的mask
images, masks = zip(*batch)
# 对图像进行padding使其大小一致
max_h = max(img.shape[1] for img in images)
max_w = max(img.shape[2] for img in images)
padded_images = []
for img in images:
pad_h = max_h - img.shape[1]
pad_w = max_w - img.shape[2]
padded_img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))
padded_images.append(padded_img)
return torch.stack(padded_images), torch.stack(masks)
# 使用自定义collate_fn
dataloader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate)
4.3 内存优化技巧
使用DALI加速:NVIDIA Data Loading Library (DALI)可以极大加速数据加载
预取数据:设置
DataLoader
的prefetch_factor
参数pin_memory:在GPU训练时设置
pin_memory=True
加速CPU到GPU的数据传输避免重复转换:对静态数据预先进行转换,而不是在
__getitem__
中转换
dataloader = DataLoader(
dataset,
batch_size=64,
num_workers=4,
pin_memory=True,
prefetch_factor=2
)
五、常见问题与解决方案
5.1 数据加载瓶颈诊断
使用PyTorch Profiler检测数据加载是否成为瓶颈:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU],
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3)
) as prof:
for i, (inputs, targets) in enumerate(dataloader):
if i >= (1 + 1 + 3): break
prof.step()
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
5.2 内存不足问题
解决方案:
减小
batch_size
使用
torch.utils.data.Subset
加载部分数据使用
Dataloader
的persistent_workers=True
选项(PyTorch 1.7+)使用内存映射文件处理大型数据集
5.3 多GPU训练数据分割
使用DistributedSampler
确保每个GPU获取不同的数据分片:
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(
dataset,
batch_size=64,
sampler=sampler,
num_workers=4
)
六、总结
PyTorch的数据加载系统提供了灵活高效的API来处理各种类型的数据。通过合理使用Dataset
和DataLoader
,结合数据增强和内存优化技巧,可以构建出满足不同需求的数据管道。关键点包括:
根据数据类型选择合适的
Dataset
实现方式合理配置
DataLoader
参数,特别是batch_size
和num_workers
使用数据增强提高模型泛化能力
针对特定任务自定义
collate_fn
监控数据加载性能,避免成为训练瓶颈
掌握这些数据准备技术,将为后续的模型训练打下坚实基础。