import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import json
train_images = np.load('../dataset/train_image.npy')
train_labels = np.load('../dataset/train_label_3.npy')
test_images = np.load('../dataset/test_image.npy')
test_labels = np.load('../dataset/test_label_3.npy')

# Collapse each label row to a single class index via argmax over axis 1.
# NOTE(review): this presumes the label files hold one-hot / per-class score
# rows of shape (N, num_classes) — TODO confirm against the dataset.
train_labels = np.argmax(train_labels, axis=1)
test_labels = np.argmax(test_labels, axis=1)


def _to_uint8(images):
    """Scale float images by 255 and convert them to uint8.

    Assumes *images* are floats in [0, 1] (implied by the ``* 255``
    scaling) — TODO confirm against the dataset.

    Values are rounded to the nearest integer instead of truncated, so e.g.
    ``0.9999999 * 255`` maps to 255 rather than 254. They are also clipped
    to [0, 255], so slightly out-of-range inputs saturate instead of
    wrapping around in the uint8 cast.
    """
    return np.clip(np.rint(images * 255), 0, 255).astype(np.uint8)


# Convert to uint8 pixel arrays before they are wrapped as PIL images by the
# transform pipelines.
train_images = _to_uint8(train_images)
test_images = _to_uint8(test_images)
class NumpyToPIL:
    """Callable transform that wraps a NumPy array sample as a PIL image.

    It is placed first in the ``transforms.Compose`` pipelines of this script,
    so the transforms after it receive a PIL image instead of a raw array.
    NOTE(review): presumably samples are (H, W) or (H, W, C) uint8 arrays,
    which ``Image.fromarray`` accepts — confirm against the dataset.
    """

    def __call__(self, sample):
        pil_image = Image.fromarray(sample)
        return pil_image
class CustomImageDataset(Dataset):
    """Map-style dataset that pairs in-memory images with their labels.

    Args:
        images: Indexable sequence of images (e.g. a NumPy array). The
            dataset length is taken from it.
        labels: Indexable sequence of labels, aligned index-for-index with
            ``images``.
        transform: Optional callable applied to each image before it is
            returned. Labels are returned unchanged.

    Raises:
        ValueError: If ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, images, labels, transform=None):
        # Fail fast on misaligned data. Otherwise the mismatch would surface
        # later as an IndexError mid-epoch, or extra labels would be ignored.
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (the length of ``images``)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``, applying ``transform`` to the image if set."""
        image = self.images[idx]
        label = self.labels[idx]
        # Compare with None instead of relying on truthiness, so a transform
        # object that happens to be falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label
# Preprocessing constants shared by both pipelines so they cannot drift apart:
# the target spatial size and the standard ImageNet per-channel mean/std.
_INPUT_SIZE = (224, 224)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: array -> PIL, resize, random horizontal flip for
# augmentation, then tensor conversion and per-channel normalization.
transform_train = transforms.Compose(
    [
        NumpyToPIL(),
        transforms.Resize(_INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Evaluation pipeline: identical to training minus the random flip, so
# test-time preprocessing is deterministic.
transform_test = transforms.Compose(
    [
        NumpyToPIL(),
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Batch size shared by both loaders. It must be defined before the loaders are
# built. TODO: 32 is a placeholder — tune it for the model and hardware.
BATCH_SIZE = 32

dataset_train = CustomImageDataset(train_images, train_labels, transform=transform_train)
dataset_test = CustomImageDataset(test_images, test_labels, transform=transform_test)

# Training loader: reshuffle each epoch, drop the final partial batch, and load
# with 8 worker processes.
train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, num_workers=8, shuffle=True, drop_last=True)
# Evaluation loader: fixed order, keeps every sample, and uses DataLoader's
# default worker setting.
test_loader = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
train_labels = train_labels.ravel()
test_labels = test_labels.ravel()


def _class_to_idx(labels):
    """Return ``{str(c): c}`` for each distinct value in *labels*, in ascending order."""
    # np.unique returns sorted unique values, so the exported mapping does not
    # depend on set iteration order. tolist() yields plain Python scalars
    # that str() and json can handle.
    return {str(c): c for c in np.unique(labels).tolist()}


def _write_class_map(mapping, stem):
    """Write *mapping* to ``<stem>.txt`` (its ``str`` form) and ``<stem>.json``."""
    with open(f'{stem}.txt', 'w', encoding='utf-8') as file:
        file.write(str(mapping))
    with open(f'{stem}.json', 'w', encoding='utf-8') as file:
        json.dump(mapping, file)


train_class_to_idx = _class_to_idx(train_labels)
test_class_to_idx = _class_to_idx(test_labels)
_write_class_map(train_class_to_idx, 'train_class')
_write_class_map(test_class_to_idx, 'test_class')