.npy格式图像如何进行深度学习模型训练处理,亲测可行

发布于:2024-07-04 ⋅ 阅读:(26) ⋅ 点赞:(0)
import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np
    from torch.utils.data import DataLoader, Dataset
    from torchvision import transforms
    from PIL import Image
    import json
    # Load the train/test image and label arrays from .npy files, in the
    # order: train images, train labels, test images, test labels.
    train_images, train_labels, test_images, test_labels = (
        np.load(npy_path)
        for npy_path in (
            '../dataset/train_image.npy',
            '../dataset/train_label_3.npy',
            '../dataset/test_image.npy',
            '../dataset/test_label_3.npy',
        )
    )

    # Collapse the labels to integer class indices by taking the position of
    # the largest entry along axis 1.
    # NOTE(review): presumably the labels are one-hot, shape (N, num_classes)
    # - confirm against the dataset.
    train_labels = train_labels.argmax(axis=1)
    test_labels = test_labels.argmax(axis=1)

    # Scale by 255 and cast to uint8 so the arrays can be handed to PIL.
    # NOTE(review): this assumes the stored images are floats in [0, 1];
    # values outside that range would wrap on the uint8 cast - TODO confirm.
    train_images = np.multiply(train_images, 255).astype(np.uint8)
    test_images = np.multiply(test_images, 255).astype(np.uint8)


    # Transform step that turns a NumPy array into a PIL image.
    class NumpyToPIL:
        """Callable transform wrapping ``Image.fromarray``.

        Intended as the first step of a ``transforms.Compose`` pipeline so the
        following steps receive a PIL image instead of a raw array.
        """

        def __call__(self, sample):
            """Return ``sample`` converted to a PIL image."""
            pil_image = Image.fromarray(sample)
            return pil_image


    class CustomImageDataset(Dataset):
        """Dataset pairing in-memory images with their labels.

        Args:
            images: Indexable collection of images; its length defines the
                dataset length.
            labels: Indexable collection of labels, looked up with the same
                index as ``images``.
            transform: Optional callable applied to each image on access.
        """

        def __init__(self, images, labels, transform=None):
            self.transform = transform
            self.labels = labels
            self.images = images

        def __len__(self):
            """Return the number of images."""
            return len(self.images)

        def __getitem__(self, idx):
            """Return the ``(image, label)`` pair stored at ``idx``.

            The image is passed through ``self.transform`` when one is set;
            the label is returned as stored.
            """
            sample, target = self.images[idx], self.labels[idx]

            # No (truthy) transform configured: hand back the raw image.
            if not self.transform:
                return sample, target

            return self.transform(sample), target


    # Preprocessing and augmentation pipelines.
    # Both convert the array to a PIL image, resize it to 224x224, convert it
    # to a tensor and normalise each channel. The training pipeline also
    # applies a random horizontal flip.
    # NOTE(review): the three-element mean/std presumably assume 3-channel
    # (RGB) input - confirm against the stored image shape.
    transform_test = transforms.Compose(
        [
            NumpyToPIL(),
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    transform_train = transforms.Compose(
        [
            NumpyToPIL(),
            transforms.Resize(size=(224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),  # 0.5 is the library default
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    # Build the datasets and their data loaders.
    # BATCH_SIZE was commented out in the original although both DataLoader
    # calls below read it, which raised NameError; the assignment is restored
    # with the value the original comment carried.
    BATCH_SIZE = 32

    dataset_train = CustomImageDataset(train_images, train_labels, transform=transform_train)
    dataset_test = CustomImageDataset(test_images, test_labels, transform=transform_test)

    # Training loader: shuffled, 8 worker processes, and the final incomplete
    # batch is dropped so every training batch has exactly BATCH_SIZE samples.
    # NOTE(review): with num_workers > 0 on platforms that spawn worker
    # processes, this code presumably needs to run under an
    # `if __name__ == "__main__":` guard - confirm how the script is launched.
    train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, num_workers=8, shuffle=True, drop_last=True)
    # Test loader: fixed order, loaded in the main process, keeps the last
    # partial batch.
    test_loader = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

    # Flatten the label arrays to 1-D before collecting the distinct classes.
    train_labels = train_labels.ravel()
    test_labels = test_labels.ravel()

    # Map each distinct label value (as a string key) to the value itself.
    train_class_to_idx = {str(cls): cls for cls in set(train_labels.tolist())}
    test_class_to_idx = {str(cls): cls for cls in set(test_labels.tolist())}

    # Write each mapping twice: once as its Python repr (.txt) and once as
    # JSON (.json). Train files are written before test files.
    for txt_path, json_path, mapping in (
        ('train_class.txt', 'train_class.json', train_class_to_idx),
        ('test_class.txt', 'test_class.json', test_class_to_idx),
    ):
        with open(txt_path, 'w') as file:
            file.write(str(mapping))
        with open(json_path, 'w', encoding='utf-8') as file:
            file.write(json.dumps(mapping))