【PYG】pyg里dataloader和torch中dataloader有什么不一样

发布于:2024-07-01 ⋅ 阅读:(17) ⋅ 点赞:(0)

torch.utils.data.DataLoadertorch_geometric.loader.DataLoader 是两个不同的加载器,它们分别用于处理不同类型的数据。以下是它们之间的主要区别:

torch.utils.data.DataLoader

torch.utils.data.DataLoader 是 PyTorch 中的通用数据加载器,用于加载任何遵循 torch.utils.data.Dataset 接口的数据集。它主要用于加载图像、文本和其他常见的数据类型。关键特性包括:

  • 通用性:适用于所有遵循 Dataset 接口的数据集。
  • 批量加载:支持批量加载数据,并行处理,数据打乱等。
  • 数据增强:可以使用 transform 进行数据增强和预处理。
  • 自定义 collate_fn:允许自定义数据批量处理函数。

torch_geometric.loader.DataLoader

torch_geometric.loader.DataLoader 是 PyTorch Geometric (PyG) 提供的数据加载器,专门用于加载图数据。它与 torch.utils.data.DataLoader 类似,但具有一些针对图数据的特性和优化。关键特性包括:

  • 图数据支持:直接支持 PyG 中的 DataBatch 对象,处理图的节点特征、边索引和其他属性。
  • 批量处理图数据:可以将多个图数据对象合并为一个批次,处理不同图的批量操作。
  • 支持稀疏表示:适合处理稀疏图结构,利用 PyG 的稀疏矩阵表示。
  • 自定义批处理:可以自定义 collate_fn 以处理复杂的批处理逻辑。

示例代码

使用 torch.utils.data.DataLoader

这是一个通用的 DataLoader 示例,适用于非图数据。

import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# 创建一些示例数据
data = torch.randn(100, 10)
labels = torch.randint(0, 2, (100,))

# 创建数据集
dataset = MyDataset(data, labels)

# 使用 DataLoader 加载数据
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 迭代加载数据
for batch_data, batch_labels in loader:
    print("Batch data shape:", batch_data.shape)
    print("Batch labels shape:", batch_labels.shape)
使用 torch_geometric.loader.DataLoader

这是一个用于加载图数据的示例。

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset

class MyGraphDataset(Dataset):
    def __init__(self, num_samples, num_nodes, num_node_features):
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        x = torch.randn(self.num_nodes, self.num_node_features)
        edge_index = torch.tensor([[i, (i + 1) % self.num_nodes] for i in range(self.num_nodes)], dtype=torch.long).t().contiguous()
        y = torch.randn(self.num_nodes, 1)
        return Data(x=x, edge_index=edge_index, y=y)

# 创建图数据集
num_samples = 100
num_nodes = 10
num_node_features = 8
dataset = MyGraphDataset(num_samples, num_nodes, num_node_features)

# 使用 PyG DataLoader 加载图数据
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 迭代加载图数据
for batch in loader:
    print("Batch node features shape:", batch.x.shape)
    print("Batch edge index shape:", batch.edge_index.shape)

主要区别

  • 数据类型torch.utils.data.DataLoader 适用于通用数据类型,torch_geometric.loader.DataLoader 专门用于图数据。
  • 批处理方式torch.utils.data.DataLoader 处理通用张量数据,torch_geometric.loader.DataLoader 处理图数据并支持将多个图合并为一个批次。
  • 自定义能力:两者都支持自定义 collate_fn,但 torch_geometric.loader.DataLoadercollate_fn 主要用于处理图数据的合并和批处理。

根据你的具体需求选择合适的数据加载器。如果处理的是图数据,推荐使用 torch_geometric.loader.DataLoader。对于其他类型的数据,可以使用 torch.utils.data.DataLoader