[AI]从零开始的YOLO数据集增强教程

发布于:2025-06-19 ⋅ 阅读:(14) ⋅ 点赞:(0)

一、前言

        不知道大家在训练YOLO时有没有遇到过这样的情况,明明数据集已经准备了很多了,但是YOLO还是不认识某个物品,或者置信度低。那么有没有办法让我们不制作新数据集的情况下让代码帮我们生成新的数据集来训练模型呢?当然有,并且现在最主流的办法就是将原本的图像进行翻转,改变亮度,以及添加噪声等。经过了这些步骤,就增加了我们数据集的多样性,相当于增加了YOLO的样本数量,这样,YOLO模型就能够认识更多样的对象,从而实现数据集增强。那么本次教程,就来教大家如何使用简单的处理代码实现对YOLO数据集的增强!

二、需要准备什么?

        既然需要对YOLO的训练数据集进行增强,这里需要大已经安装好YOLO环境并且对YOLO的训练非常熟悉。如果你还没有安装好YOLO的推理环境可以直接看下面的教程:

YOLO环境搭建:[AI]小白向的YOLO安装教程-CSDN博客

如果你还不会训练YOLO模型可以看下面的教程:

YOLO模型训练:[AI]YOLO如何训练对象检测模型(详细)_yolo模型-CSDN博客

当部署好YOLO环境并且对YOLO推理非常熟悉以后就可以进行下面的步骤了。

三、YOLO数据集增强

        这里我们需要对数据集进行增强,首先我们需要一个已经制作好的数据集,这里数据集的数量不用太多,我这里就准备了200张已经框好的数据集,用于识别花卉的碳黑病:

准备好对应的数据集,我们创新一个名为“Augment.py”的文件,然后把下方的代码粘贴进这个py文件中,如图所示:

# -*- coding: utf-8 -*-

import torch
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import random
random.seed(0)
 
 
class DataAugmentationOnDetection:
    def __init__(self):
        super(DataAugmentationOnDetection, self).__init__()
 
    # 以下的几个参数类型中,image的类型全部如下类型
    # 参数类型: image:Image.open(path)
    def resize_keep_ratio(self, image, boxes, target_size):
        """
            参数类型: image:Image.open(path), boxes:Tensor, target_size:int
            功能:将图像缩放到size尺寸,调整相应的boxes,同时保持长宽比(最长的边是target size
        """
        old_size = image.size[0:2]  # 原始图像大小
        # 取最小的缩放比例
        ratio = min(float(target_size) / (old_size[i]) for i in range(len(old_size)))  # 计算原始图像宽高与目标图像大小的比例,并取其中的较小值
        new_size = tuple([int(i * ratio) for i in old_size])  # 根据上边求得的比例计算在保持比例前提下得到的图像大小
        # boxes 不用变化,因为是等比例变化
        return image.resize(new_size, Image.BILINEAR), boxes
 
    def resizeDown_keep_ratio(self, image, boxes, target_size):
        """ 与上面的函数功能类似,但它只降低图片的尺寸,不会扩大图片尺寸"""
        old_size = image.size[0:2]  # 原始图像大小
        # 取最小的缩放比例
        ratio = min(float(target_size) / (old_size[i]) for i in range(len(old_size)))  # 计算原始图像宽高与目标图像大小的比例,并取其中的较小值
        ratio = min(ratio, 1)
        new_size = tuple([int(i * ratio) for i in old_size])  # 根据上边求得的比例计算在保持比例前提下得到的图像大小
 
        # boxes 不用变化,因为是等比例变化
        return image.resize(new_size, Image.BILINEAR), boxes
 
    def resize(self, img, boxes, size):
        # ---------------------------------------------------------
        # 类型为 img=Image.open(path),boxes:Tensor,size:int
        # 功能为:将图像长和宽缩放到指定值size,并且相应调整boxes
        # ---------------------------------------------------------
        return img.resize((size, size), Image.BILINEAR), boxes
 
    def random_flip_horizon(self, img, boxes, h_rate=1):
        # -------------------------------------
        # 随机水平翻转
        # -------------------------------------
        if np.random.random() < h_rate:
            transform = transforms.RandomHorizontalFlip(p=1)
            img = transform(img)
            if len(boxes) > 0:
                x = 1 - boxes[:, 1]
                boxes[:, 1] = x
        return img, boxes
 
    def random_flip_vertical(self, img, boxes, v_rate=1):
        # 随机垂直翻转
        if np.random.random() < v_rate:
            transform = transforms.RandomVerticalFlip(p=1)
            img = transform(img)
            if len(boxes) > 0:
                y = 1 - boxes[:, 2]
                boxes[:, 2] = y
        return img, boxes
 
    def center_crop(self, img, boxes, target_size=None):
        # -------------------------------------
        # 中心裁剪 ,裁剪成 (size, size) 的正方形, 仅限图形,w,h
        # 这里用比例是很难算的,转成x1,y1, x2, y2格式来计算
        # -------------------------------------
        w, h = img.size
        size = min(w, h)
        if len(boxes) > 0:
            # 转换到xyxy格式
            label = boxes[:, 0].reshape([-1, 1])
            x_, y_, w_, h_ = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
            x1 = (w * x_ - 0.5 * w * w_).reshape([-1, 1])
            y1 = (h * y_ - 0.5 * h * h_).reshape([-1, 1])
            x2 = (w * x_ + 0.5 * w * w_).reshape([-1, 1])
            y2 = (h * y_ + 0.5 * h * h_).reshape([-1, 1])
            boxes_xyxy = torch.cat([x1, y1, x2, y2], dim=1)
            # 边框转换
            if w > h:
                boxes_xyxy[:, [0, 2]] = boxes_xyxy[:, [0, 2]] - (w - h) / 2
            else:
                boxes_xyxy[:, [1, 3]] = boxes_xyxy[:, [1, 3]] - (h - w) / 2
            in_boundary = [i for i in range(boxes_xyxy.shape[0])]
            for i in range(boxes_xyxy.shape[0]):
                # 判断x是否超出界限
                if (boxes_xyxy[i, 0] < 0 and boxes_xyxy[i, 2] < 0) or (boxes_xyxy[i, 0] > size and boxes_xyxy[i, 2] > size):
                    in_boundary.remove(i)
                # 判断y是否超出界限
                elif (boxes_xyxy[i, 1] < 0 and boxes_xyxy[i, 3] < 0) or (boxes_xyxy[i, 1] > size and boxes_xyxy[i, 3] > size):
                    in_boundary.append(i)
            boxes_xyxy = boxes_xyxy[in_boundary]
            boxes = boxes_xyxy.clamp(min=0, max=size).reshape([-1, 4])  # 压缩到固定范围
            label = label[in_boundary]
            # 转换到YOLO格式
            x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
            xc = ((x1 + x2) / (2 * size)).reshape([-1, 1])
            yc = ((y1 + y2) / (2 * size)).reshape([-1, 1])
            wc = ((x2 - x1) / size).reshape([-1, 1])
            hc = ((y2 - y1) / size).reshape([-1, 1])
            boxes = torch.cat([xc, yc, wc, hc], dim=1)
        # 图像转换
        transform = transforms.CenterCrop(size)
        img = transform(img)
        if target_size:
            img = img.resize((target_size, target_size), Image.BILINEAR)
        if len(boxes) > 0:
            return img, torch.cat([label.reshape([-1, 1]), boxes], dim=1)
        else:
            return img, boxes
 
    # ------------------------------------------------------
    # 以下img皆为Tensor类型
    # ------------------------------------------------------
 
    def random_bright(self, img, u=120, p=1):
        # -------------------------------------
        # 随机亮度变换
        # -------------------------------------
        if np.random.random() < p:
            alpha=np.random.uniform(-u, u)/255
            img += alpha
            img=img.clamp(min=0.0, max=1.0)
        return img
 
    def random_contrast(self, img, lower=0.5, upper=1.5, p=1):
        # -------------------------------------
        # 随机增强对比度
        # -------------------------------------
        if np.random.random() < p:
            alpha=np.random.uniform(lower, upper)
            img*=alpha
            img=img.clamp(min=0, max=1.0)
        return img
 
    def random_saturation(self, img,lower=0.5, upper=1.5, p=1):
        # 随机饱和度变换,针对彩色三通道图像,中间通道乘以一个值
        if np.random.random() < p:
            alpha=np.random.uniform(lower, upper)
            img[1]=img[1]*alpha
            img[1]=img[1].clamp(min=0,max=1.0)
        return img
 
    def add_gasuss_noise(self, img, mean=0, std=0.1):
        noise=torch.normal(mean,std,img.shape)
        img+=noise
        img=img.clamp(min=0, max=1.0)
        return img
 
    def add_salt_noise(self, img):
        noise=torch.rand(img.shape)
        alpha=np.random.random()/5 + 0.7
        img[noise[:,:,:]>alpha]=1.0
        return img
 
    def add_pepper_noise(self, img):
        noise=torch.rand(img.shape)
        alpha=np.random.random()/5 + 0.7
        img[noise[:, :, :]>alpha]=0
        return img
 
 
def plot_pics(img, boxes):
    # 显示图像和候选框,img是Image.Open()类型, boxes是Tensor类型
    plt.imshow(img)
    label_colors = [(213, 110, 89)]
    w, h = img.size
    for i in range(boxes.shape[0]):
        box = boxes[i, 1:]
        xc, yc, wc, hc = box
        x = w * xc - 0.5 * w * wc
        y = h * yc - 0.5 * h * hc
        box_w, box_h = w * wc, h * hc
        plt.gca().add_patch(plt.Rectangle(xy=(x, y), width=box_w, height=box_h,
                                          edgecolor=[c / 255 for c in label_colors[0]],
                                          fill=False, linewidth=2))
    plt.show()
 
def get_image_list(image_path):
    # 根据图片文件,查找所有图片并返回列表
    files_list = []
    for root, sub_dirs, files in os.walk(image_path):
        for special_file in files:
            special_file = special_file[0: len(special_file)]
            files_list.append(special_file)
    return files_list
 
def get_label_file(label_path, image_name):
    # 根据图片信息,查找对应的label
    fname = os.path.join(label_path, image_name[0: len(image_name)-4]+".txt")
    data2 = []
    if not os.path.exists(fname):
        return data2
    if os.path.getsize(fname) == 0:
        return data2
    else:
        with open(fname, 'r', encoding='utf-8') as infile:
            # 读取并转换标签
            for line in infile:
                data_line = line.strip("\n").split()
                data2.append([float(i) for i in data_line])
    return data2
 
 
def save_Yolo(img, boxes, save_path, prefix, image_name):
    # img: 需要时Image类型的数据, prefix 前缀
    # 将结果保存到save path指示的路径中
    if not os.path.exists(save_path) or \
            not os.path.exists(os.path.join(save_path, "images")):
        os.makedirs(os.path.join(save_path, "images"))
        os.makedirs(os.path.join(save_path, "labels"))
    try:
        img.save(os.path.join(save_path, "images", prefix + image_name))
        with open(os.path.join(save_path, "labels", prefix + image_name[0:len(image_name)-4] + ".txt"), 'w', encoding="utf-8") as f:
            if len(boxes) > 0:  # 判断是否为空
                # 写入新的label到文件中
                for data in boxes:
                    str_in = ""
                    for i, a in enumerate(data):
                        if i == 0:
                            str_in += str(int(a))
                        else:
                            str_in += " " + str(float(a))
                    f.write(str_in + '\n')
    except:
        print("ERROR: ", image_name, " is bad.")
 
 
def runAugumentation(image_path, label_path, save_path):
    image_list = get_image_list(image_path)
    for image_name in image_list:
        print("dealing: " + image_name)
        img = Image.open(os.path.join(image_path, image_name))
        boxes = get_label_file(label_path, image_name)
        boxes = torch.tensor(boxes)
        # 下面是执行的数据增强功能,可自行选择
        # Image类型的参数
        DAD = DataAugmentationOnDetection()
 
        """ 尺寸变换   """
        # 缩小尺寸
        # t_img, t_boxes = DAD.resizeDown_keep_ratio(img, boxes, 1024)
        # save_Yolo(t_img, boxes, save_path, prefix="rs_", image_name=image_name)
        # 水平旋转
        t_img, t_boxes = DAD.random_flip_horizon(img, boxes.clone())
        save_Yolo(t_img, t_boxes, save_path, prefix="fh_", image_name=image_name)
        # 竖直旋转
        t_img, t_boxes = DAD.random_flip_vertical(img, boxes.clone())
        save_Yolo(t_img, t_boxes, save_path, prefix="fv_", image_name=image_name)
        # center_crop
        t_img, t_boxes = DAD.center_crop(img, boxes.clone(), 1024)
        save_Yolo(t_img, t_boxes, save_path, prefix="cc_", image_name=image_name)
 
        """ 图像变换,用tensor类型"""
        to_tensor = transforms.ToTensor()
        to_image = transforms.ToPILImage()
        img = to_tensor(img)
 
        # random_bright
        t_img, t_boxes = DAD.random_bright(img.clone()), boxes
        save_Yolo(to_image(t_img), boxes, save_path, prefix="rb_", image_name=image_name)
        # random_contrast 对比度变化
        t_img, t_boxes = DAD.random_contrast(img.clone()), boxes
        save_Yolo(to_image(t_img), boxes, save_path, prefix="rc_", image_name=image_name)
        # random_saturation 饱和度变化
        t_img, t_boxes = DAD.random_saturation(img.clone()), boxes
        save_Yolo(to_image(t_img), boxes, save_path, prefix="rs_", image_name=image_name)
        # 高斯噪声
        t_img, t_boxes = DAD.add_gasuss_noise(img.clone()), boxes
        save_Yolo(to_image(t_img), boxes, save_path, prefix="gn_", image_name=image_name)
        # add_salt_noise
        t_img, t_boxes = DAD.add_salt_noise(img.clone()), boxes
        save_Yolo(to_image(t_img), boxes, save_path, prefix="sn_", image_name=image_name)
        # add_pepper_noise
        t_img, t_boxes = DAD.add_pepper_noise(img.clone()), boxes
        save_Yolo(to_image(t_img), boxes, save_path, prefix="pn_", image_name=image_name)
 
        print("end:     " + image_name)
 
 
if __name__ == '__main__':
    # 图像和标签文件夹
    image_path = r"./train/images"
    label_path = r"./train/labels"
    save_path = r"./save"    # 结果保存位置路径,可以是一个不存在的文件夹
    # 运行
    runAugumentation(image_path, label_path, save_path)

粘贴完成以后,如图所示:

在上方的代码中,我们找到“if __name__ == '__main__':”的位置,在使用代码前,需要对这里的路径进行一些简单的修改。来到上述位置后,我们可以看到如图所示的代码:

这里在配置之前有一个前提,那就是我们的数据集已经制作好了。我的数据集结构如图所示:

如上图可以看到,我们的“Augment.py”与train目录在同一级,在train目录中有images目录与labels目录:

这就是非常常见的YOLO目录结构,这里就不多说了。

根据代码中的变量,我们可知,第一个“image_path”需要传入我们的数据集的图片路径,后面的“label_path”需要传入我们数据集中标签的路径。最后一个“save_path”就是我们保存新生成的数据集与标签的路径。修改好上面的内容以后,我们直接保存即可。我们进入YOLO的虚拟环境,然后直接运行这个py文件即可:

python .\Augment.py

运行以后,我们就可以看到代码开始帮我们处理数据集了:

在我们设置的保存路径中,代码已经保存了帮我们生成的数据集:

等程序执行完成以后,我们可以看到对应的文件夹中有1800个文件,相当于现在我们使用代码增强出来的数据集是我们原本数据集的9倍。这些数据集已经涵盖了大部分的情况:

在labels文件夹中也生成了对应的标签:

至此,我们使用YOLO增强数据集就完成了。

四、结语

        在本次教程中,我们通过对YOLO数据集的增强,实现了数据集多样性的扩展,极大的减少了人工框选的成本以及样本拍摄的数量,那么最后,感谢大家的观看!


网站公告

今日签到

点亮在社区的每一天
去签到