【Pytorch】18.创建自定义数据集并根据文件名或对应文件名的文本文件获取labels

发布于:2024-05-24 ⋅ 阅读:(66) ⋅ 点赞:(0)

源码

MNIST_Training_By_FileName_Dataset
MNIST_Training_By_TXTLabel

简介

本文主要探讨两种不同的数据集获取labels的方法

  • 根据图片的文件名中获取文件标签
    在这里插入图片描述

  • 根据与图片名称相同的.txt文件获取文件名
    在这里插入图片描述

根据图片名称获取labels

主要的区别在__init__方法中

    def __init__(self, root_path, train, transform=None):
        self.root_path = root_path
        self.transform = transform
        if train:
            self.root_path = os.path.join(self.root_path, 'training')
        else:
            self.root_path = os.path.join(self.root_path, 'testing')

        self.img_paths = []
        self.labels = []
        for label_path in os.listdir(self.root_path):
            img_path = os.path.join(self.root_path, label_path)
            if os.path.isdir(img_path):
                for img in os.listdir(img_path):
                    # 使用正则获取图片名称中的信息
                    match = re.search(r'_(\d+)', img)
                    label = match.group(1)
                    # print(f'label: {label}')
                    pre_img_path = os.path.join(img_path, img)
                    self.img_paths.append(pre_img_path)
                    self.labels.append(label)

我们可以看到在我们获取图片名称后,我们需要使用正则化来提取文件名中含有的label:xxx_0.png

根据txt文件获取labels

主要的区别在__getitem__方法


    def __getitem__(self, index):
        # ../datasets/mnist_png/training/.../1.png
        img = self.imgs[index]

        # 仅获取文件名
        # 1.png
        img_name = os.path.basename(img)
        img = Image.open(img).convert('L')
        if self.transform is not None:
            img = self.transform(img)
        # ../datasets/mnist_png/labels/1.txt
        label_dir = os.path.join(self.label_path, img_name.replace('.png', '.txt'))
        # 从文件中获取内容
        with open(label_dir, 'r') as f:
            label = f.read().strip()
        return img, label

  1. 我们需要现根据图片的相对路径通过os.path.basename来获取文件名
  2. 然后根据图片名使用img_name.replace来将.png换成.txt然后在对应的labels文件夹下找到对应名称的文件来获取标签