19.数据增强技术

发布于:2025-07-18 ⋅ 阅读:(12) ⋅ 点赞:(0)

19.1 图像水平翻转与垂直翻转

import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    """Show a grid of independently augmented copies of ``img``.

    ``aug`` is called once per grid cell (``num_rows * num_cols`` times) on the
    same input, and every result is drawn with ``imshow`` on its own axis.

    Args:
        img: Input handed unchanged to ``aug`` on every call.
        aug: Callable taking ``img`` and returning something ``imshow`` can draw.
        num_rows: Number of rows in the grid.
        num_cols: Number of columns in the grid.
        scale: Per-cell size factor; the figure is
            ``(num_cols * scale, num_rows * scale)``.
    """
    samples = [aug(img) for _ in range(num_rows * num_cols)]
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False makes subplots always return a 2-D array of axes, so a
    # 1x1 grid no longer fails on a bare Axes object that has no .flatten().
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    for ax, sample in zip(axes.flatten(), samples):
        ax.imshow(sample)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
# NOTE(review): `image` is not defined anywhere in the visible snippet —
# presumably a PIL image loaded beforehand (e.g. with Image.open); confirm.
# Random left-right flip.
apply(img=image, aug=torchvision.transforms.RandomHorizontalFlip())
# Random up-down flip.
apply(img=image, aug=torchvision.transforms.RandomVerticalFlip())

在这里插入图片描述

19.2 图像随机裁剪与色彩调整

import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    """Show a grid of independently augmented copies of ``img``.

    ``aug`` is called once per grid cell (``num_rows * num_cols`` times) on the
    same input, and every result is drawn with ``imshow`` on its own axis.

    Args:
        img: Input handed unchanged to ``aug`` on every call.
        aug: Callable taking ``img`` and returning something ``imshow`` can draw.
        num_rows: Number of rows in the grid.
        num_cols: Number of columns in the grid.
        scale: Per-cell size factor; the figure is
            ``(num_cols * scale, num_rows * scale)``.
    """
    samples = [aug(img) for _ in range(num_rows * num_cols)]
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False makes subplots always return a 2-D array of axes, so a
    # 1x1 grid no longer fails on a bare Axes object that has no .flatten().
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    for ax, sample in zip(axes.flatten(), samples):
        ax.imshow(sample)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
# NOTE(review): `image` is not defined anywhere in the visible snippet —
# presumably a PIL image loaded beforehand (e.g. with Image.open); confirm.
# Random crop resized to 200x200.
apply(
    img=image,
    aug=torchvision.transforms.RandomResizedCrop(
        size=(200, 200), scale=(0.2, 1), ratio=(1, 2)
    ),
)
# Random brightness / contrast / saturation / hue adjustment.
apply(
    img=image,
    aug=torchvision.transforms.ColorJitter(
        brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5
    ),
)

在这里插入图片描述

19.3 图像整体增强变换

import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    """Show a grid of independently augmented copies of ``img``.

    ``aug`` is called once per grid cell (``num_rows * num_cols`` times) on the
    same input, and every result is drawn with ``imshow`` on its own axis.

    Args:
        img: Input handed unchanged to ``aug`` on every call.
        aug: Callable taking ``img`` and returning something ``imshow`` can draw.
        num_rows: Number of rows in the grid.
        num_cols: Number of columns in the grid.
        scale: Per-cell size factor; the figure is
            ``(num_cols * scale, num_rows * scale)``.
    """
    samples = [aug(img) for _ in range(num_rows * num_cols)]
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False makes subplots always return a 2-D array of axes, so a
    # 1x1 grid no longer fails on a bare Axes object that has no .flatten().
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    for ax, sample in zip(axes.flatten(), samples):
        ax.imshow(sample)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
# Build the three individual augmentations, then chain them in the order
# flip -> colour jitter -> random resized crop.
# NOTE(review): `image` is not defined anywhere in the visible snippet —
# presumably a PIL image loaded beforehand (e.g. with Image.open); confirm.
loc_aug = torchvision.transforms.RandomHorizontalFlip()
color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5
)
shape_aug = torchvision.transforms.RandomResizedCrop(
    size=(200, 200), scale=(0.1, 1), ratio=(0.5, 2)
)
augs = torchvision.transforms.Compose(transforms=[loc_aug, color_aug, shape_aug])
apply(img=image, aug=augs)

在这里插入图片描述

19.4 基于CIFAR-10数据集的图像增强效果对比

在这里插入图片描述

################################################################################################################
#ResNet
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision.models as models
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
# Use Times New Roman for all text in the figures drawn below.
plt.rcParams.update({'font.family': ['Times New Roman']})
class Reshape(torch.nn.Module):
    """Reshape the input into a batch of single-channel 28x28 images."""

    def forward(self, x):
        """Return ``x`` laid out as ``[batch, 1, 28, 28]``.

        The batch size is inferred from the number of elements in ``x``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed tensors; contiguous inputs still share storage as before.
        return x.reshape(-1, 1, 28, 28)  # [bs, 1, 28, 28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    """Plot per-epoch training loss, training accuracy and test accuracy.

    All three series are drawn on one figure against a 1-based epoch index
    whose length is taken from ``train_loss_list``; the two accuracy curves
    are dashed.

    Args:
        train_loss_list: Training loss, one value per epoch.
        train_acc_list: Training accuracy, one value per epoch.
        test_acc_list: Test accuracy, one value per epoch.
        title: Title shown above the plot.
    """
    epoch_axis = range(1, len(train_loss_list) + 1)
    # Each entry pairs a series with the keyword arguments for its line.
    curves = (
        (train_loss_list, {'label': 'Train Loss'}),
        (train_acc_list, {'label': 'Train Acc', 'linestyle': '--'}),
        (test_acc_list, {'label': 'Test Acc', 'linestyle': '--'}),
    )

    plt.figure(figsize=(4, 3))
    for values, line_kwargs in curves:
        plt.plot(epoch_axis, values, **line_kwargs)

    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model, train_data, test_data, num_epochs):
    """Train ``model`` and report train/test accuracy after every epoch.

    Uses the module-level ``device``, ``optimizer`` and ``CEloss`` objects,
    which must be defined before this function is called. After the last
    epoch the collected curves are passed to ``plot_metrics``.

    Args:
        model: Network to optimize; called as ``model(X)`` on each batch.
        train_data: Iterable of ``(X, y)`` training batches.
        test_data: Iterable of ``(X, y)`` evaluation batches.
        num_epochs: Number of passes over ``train_data``.

    Returns:
        The same ``model`` object. It is left in evaluation mode, because
        each epoch ends with the evaluation pass.
    """
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        # Training pass: enable training behaviour of dropout / batch norm.
        model.train()
        total_loss = 0
        total_acc_sample = 0
        total_samples = 0
        loop1 = tqdm(train_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X, y in loop1:
            X = X.to(device)
            y = y.to(device)
            y_hat = model(X)
            loss = CEloss(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_size = X.shape[0]
            # Weight each batch by its size so a smaller last batch does not
            # skew the per-sample averages computed below.
            total_loss += loss.item() * batch_size
            y_pred = y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true = y.detach().cpu().numpy()
            total_acc_sample += accuracy_score(y_true, y_pred) * batch_size
            total_samples += batch_size

        # Evaluation pass: switch to eval mode so batch-norm running
        # statistics are used (and not updated from test batches), and skip
        # autograd bookkeeping since no backward pass happens here.
        model.eval()
        test_acc_samples = 0
        test_samples = 0
        loop2 = tqdm(test_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        with torch.no_grad():
            for X, y in loop2:
                X = X.to(device)
                y = y.to(device)
                y_hat = model(X)
                y_pred = y_hat.argmax(dim=1).cpu().numpy()
                y_true = y.cpu().numpy()
                test_acc_samples += accuracy_score(y_true, y_pred) * X.shape[0]
                test_samples += X.shape[0]

        avg_train_loss = total_loss / total_samples
        avg_train_acc = total_acc_sample / total_samples
        avg_test_acc = test_acc_samples / test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
def init_weights(m):
    """Xavier-uniform initialize the weight of a Linear or Conv2d module.

    Intended for ``Module.apply``. Only modules whose exact type is
    ``nn.Linear`` or ``nn.Conv2d`` are touched (subclasses are skipped),
    and biases are left unchanged.
    """
    if type(m) in (nn.Linear, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
################################################################################################################
# Training pipeline: random horizontal flip, random brightness jitter,
# conversion to a tensor, then normalization.
transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Test pipeline is deterministic: random augmentation is applied to the
# training set only, so the reported test accuracy is reproducible and
# measured on unmodified images.
transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_img = torchvision.datasets.CIFAR10(
    root="./data", train=True, transform=transforms_train, download=True
)
test_img = torchvision.datasets.CIFAR10(
    root="./data", train=False, transform=transforms_test, download=True
)
# Shuffle only the training batches; test order is kept fixed.
train_data = DataLoader(train_img, batch_size=128, num_workers=4, shuffle=True)
test_data = DataLoader(test_img, batch_size=128, num_workers=4, shuffle=False)
################################################################################################################
# Keep using the second GPU when one exists, but fall back to the first GPU
# on single-GPU machines (where "cuda:1" is an invalid device) and to the
# CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Load ResNet-50 with pretrained weights and swap in a 10-class head.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` — confirm against the installed version before changing.
model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 10)
model.to(device)
# Xavier-initialize only the freshly created head. Applying init_weights to
# the whole model would overwrite every pretrained Conv2d weight loaded above.
model.fc.apply(init_weights)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
CEloss = nn.CrossEntropyLoss()
model = train_model(model, train_data, test_data, num_epochs=20)
################################################################################################################

在这里插入图片描述


网站公告

今日签到

点亮在社区的每一天
去签到