源码
MNIST_Training_By_FileName_Dataset
MNIST_Training_By_TXTLabel
简介
本文主要探讨两种不同的数据集获取labels
的方法
根据图片的文件名中获取文件标签
根据与图片名称相同的
.txt
文件获取文件名
根据图片名称获取labels
主要的区别在__init__
方法中
def __init__(self, root_path, train, transform=None):
self.root_path = root_path
self.transform = transform
if train:
self.root_path = os.path.join(self.root_path, 'training')
else:
self.root_path = os.path.join(self.root_path, 'testing')
self.img_paths = []
self.labels = []
for label_path in os.listdir(self.root_path):
img_path = os.path.join(self.root_path, label_path)
if os.path.isdir(img_path):
for img in os.listdir(img_path):
# 使用正则获取图片名称中的信息
match = re.search(r'_(\d+)', img)
label = match.group(1)
# print(f'label: {label}')
pre_img_path = os.path.join(img_path, img)
self.img_paths.append(pre_img_path)
self.labels.append(label)
我们可以看到在我们获取图片名称后,我们需要使用正则化来提取文件名中含有的label:xxx_0.png
根据txt文件获取labels
主要的区别在__getitem__
方法
def __getitem__(self, index):
# ../datasets/mnist_png/training/.../1.png
img = self.imgs[index]
# 仅获取文件名
# 1.png
img_name = os.path.basename(img)
img = Image.open(img).convert('L')
if self.transform is not None:
img = self.transform(img)
# ../datasets/mnist_png/labels/1.txt
label_dir = os.path.join(self.label_path, img_name.replace('.png', '.txt'))
# 从文件中获取内容
with open(label_dir, 'r') as f:
label = f.read().strip()
return img, label
- 我们需要现根据图片的相对路径通过
os.path.basename
来获取文件名 - 然后根据图片名使用
img_name.replace
来将.png
换成.txt
然后在对应的labels文件夹下找到对应名称的文件来获取标签