【数据增强】精细化贴图数据增强

发布于:2025-07-05 ⋅ 阅读:(19) ⋅ 点赞:(0)

1.任务背景

假设我有100个苹果的照片,我需要把这些照片粘贴到传送带照片上,模拟“传送带苹果检测”场景。
这种贴图的方式更加合理一些,因为yolo之类的mosaic贴图,会把图像弄的非常支离破碎。
现在我需要随机选择几张苹果图像,每张苹果图像至少使用x次,并且保证苹果(新苹果之间、新旧标注信息之间)不重叠,且苹果大小范围可以自由指定
效果图如下(人工核查,xml也是正确的,我粘贴的是电力部件)。在真实使用中,如果需要确保信粘贴内容与已有标签不重复,请在在这里插入图片描述
):

2.具体逻辑

2.1 项目描述

这是一个用于计算机视觉任务(如目标检测)的智能数据增强脚本。它通过将小物体图像(Patches)以复杂且真实的方式粘贴到大型背景图像上,来批量生成高质量的训练数据集。脚本的核心是“场景合成”,旨在创建包含多个、尺寸合理且互不重叠物体的复杂图像,从而有效提升模型的鲁棒性和泛化能力。

2.2 整体流程

脚本的运行流程遵循一个清晰的、基于统计学控制的“事件驱动”模型:

  1. 计算任务总量: 首先,根据用户设定的参数(小图总数、期望重复次数、每背景粘贴数),脚本会计算出需要生成的增强图片总数。
  2. 循环生成: 程序会循环执行所计算出的总次数。在每一次循环中,它会独立地完成一次“场景合成”操作。
  3. 场景合成:
    • 选取素材: 随机选择一张背景图和指定数量的随机小图。
    • 智能调整与放置: 对每一个选中的小图,依次进行:
      • 动态缩放: 根据相对于背景图的“最大/最小面积百分比”要求,自动缩放小图,确保尺寸合理。
      • 碰撞检测: 为小图寻找一个随机的、且不与任何已放置小图重叠的位置。
      • 数据增强: 对小图进行随机的水平或垂直翻转。
      • 粘贴与记录: 将处理后的小图粘贴到背景图上,并同步更新XML标注信息。
    • 保存输出: 将最终合成的图片和包含所有新物体标注的XML文件,以唯一编号保存到输出文件夹。

2.3主要功能

  • 场景化组合: 能在单一背景上粘贴多个随机物体,模拟复杂的真实世界场景。
  • 动态尺寸调整: 摒弃了固定的像素限制,采用相对于背景的面积百分比来约束贴图大小,使其能智能适应任意尺寸的背景图。
  • 防重叠粘贴: 核心亮点功能。通过碰撞检测算法确保粘贴的物体之间(包括原有的标注信息)互不重叠,显著提升了生成数据的质量和真实感。
  • 可控的随机性: 用户可以通过参数精确控制最终生成数据集的总量,并使每个物体的平均使用次数在统计上趋于稳定。
  • 自动化标注: 在生成图像的同时,会自动创建和更新对应的PASCAL VOC格式的XML标注文件,省去手动标注的繁琐工作。

2.4 参数说明

  • FOLDER_A, FOLDER_B, OUTPUT...: 用于定义素材和输出结果的路径。
  • OBJECT_NAME: 指定粘贴的物体在XML文件中被称作什么类别名。
  • NUM_PATCHES_PER_BG: **(核心)**设定每张生成的图片上要粘贴多少个小图/物体。
  • REPEATS_PER_PATCH: **(核心)**设定数据集中每一种小图期望被重复使用的平均次数,用于计算总生成量。
  • MAX_AREA_PERCENTAGE: 定义贴图相对于背景的最大允许面积(例如 0.5 代表50%)。
  • MIN_AREA_PERCENTAGE: 定义贴图相对于背景的最小允许面积(例如 0.05 代表5%),小于此值会被自动放大。
  • MAX_PLACEMENT_TRIES: 在为小图寻找不重叠位置时的最大尝试次数,这是一个防止在拥挤场景下无限循环的安全设置。

3.代码实现

import os
import random
import xml.etree.ElementTree as ET
from PIL import Image
from tqdm import tqdm
import copy
import math
import cv2

# --- 配置区 ---

# 输入文件夹
FOLDER_A = r'E:\data\baidu_pic\aaa'  # 存放小图像的文件夹
FOLDER_B = r'E:\data\baidu_pic\background'  # 存放带XML的大图像的文件夹

# 输出文件夹 (如果不存在,脚本会自动创建)
OUTPUT_IMAGES_FOLDER = r'E:\data\baidu_pic\hecheng\iamges'
OUTPUT_ANNOTATIONS_FOLDER = r'E:\data\baidu_pic\hecheng\xlms'

# 要在XML中添加的Object名称
OBJECT_NAME = 'rdg'

# --- 关键参数 ---
# 1. 为每张大图粘贴多少个小图
NUM_PATCHES_PER_BG = 3

# 2. 数据集中每张小图期望重复使用的次数
# 这将决定“小图池”的大小
REPEATS_PER_PATCH = 2

# 3.新增:相对面积限制 
# 使用基于背景图面积的百分比来控制贴图的最终尺寸
MAX_AREA_PERCENTAGE = 0.01  # 贴图面积不得超过背景的1%
MIN_AREA_PERCENTAGE = 0.005 # 贴图面积不得小于背景的0.5%

# 4.支持的图像文件扩展名
SUPPORTED_IMAGE_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp']

# 5.为防止因空间不足而无限循环,设定为每个贴图寻找不重叠位置的最大尝试次数
MAX_PLACEMENT_TRIES = 100

# --- 辅助函数 ---

def is_overlapping(box1, box2):
    """检查两个边界框是否重叠。 box = (xmin, ymin, xmax, ymax)"""
    if box1[2] < box2[0] or box2[2] < box1[0]:
        return False
    if box1[3] < box2[1] or box2[3] < box1[1]:
        return False
    return True

def get_b_image_basenames(folder_path):
    filenames = []
    for f in os.listdir(folder_path):
        basename, ext = os.path.splitext(f)
        if ext.lower() in SUPPORTED_IMAGE_FORMATS and os.path.exists(os.path.join(folder_path, basename + '.xml')):
            filenames.append(basename)
    return list(set(filenames))

def update_xml_annotation(xml_root, object_name, xmin, ymin, xmax, ymax):
    obj = ET.SubElement(xml_root, 'object')
    ET.SubElement(obj, 'name').text = object_name
    ET.SubElement(obj, 'pose').text = 'Unspecified'
    ET.SubElement(obj, 'truncated').text = '0'
    ET.SubElement(obj, 'difficult').text = '0'
    bndbox = ET.SubElement(obj, 'bndbox')
    ET.SubElement(bndbox, 'xmin').text = str(int(xmin))
    ET.SubElement(bndbox, 'ymin').text = str(int(ymin))
    ET.SubElement(bndbox, 'xmax').text = str(int(xmax))
    ET.SubElement(bndbox, 'ymax').text = str(int(ymax))

def find_image_ext(folder, basename):
    for ext in SUPPORTED_IMAGE_FORMATS:
        if os.path.exists(os.path.join(folder, basename + ext)):
            return ext
    return None

# --- 主逻辑区 ---

def main():
    print("--- 开始执行数据增强任务 (带完全防重叠逻辑) ---")

    os.makedirs(OUTPUT_IMAGES_FOLDER, exist_ok=True)
    os.makedirs(OUTPUT_ANNOTATIONS_FOLDER, exist_ok=True)
    
    all_small_images = [f for f in os.listdir(FOLDER_A) if os.path.splitext(f)[1].lower() in SUPPORTED_IMAGE_FORMATS]
    if not all_small_images:
        print(f"错误: 文件夹 '{FOLDER_A}' 中没有找到任何图像文件。")
        return

    background_basenames = get_b_image_basenames(FOLDER_B)
    if not background_basenames:
        print(f"错误: 文件夹 '{FOLDER_B}' 中没有找到任何带有XML配对的图像文件。")
        return

    num_total_patches = len(all_small_images)
    if NUM_PATCHES_PER_BG <= 0:
        print("错误: NUM_PATCHES_PER_BG 必须大于0。")
        return
        
    total_operations = int((num_total_patches * REPEATS_PER_PATCH) / NUM_PATCHES_PER_BG)
    print(f"将生成 {total_operations} 张增强图片。")

    for i in tqdm(range(total_operations), desc="生成增强图片中"):
        try:
            bg_basename = random.choice(background_basenames)
            bg_ext = find_image_ext(FOLDER_B, bg_basename)
            if not bg_ext: continue
            
            bg_image_path = os.path.join(FOLDER_B, bg_basename + bg_ext)
            xml_path = os.path.join(FOLDER_B, bg_basename + '.xml')
            
            background_image_pil = Image.open(bg_image_path).convert("RGBA")
            tree = ET.parse(xml_path)
            xml_root = tree.getroot()
            bg_width, bg_height = background_image_pil.size
            background_area = bg_width * bg_height

            # --- 核心改动:预加载原始XML中的所有物体边界框 ---
            placed_boxes = []
            for obj in xml_root.findall('object'):
                try:
                    bndbox = obj.find('bndbox')
                    # 将XML中的坐标文本转换为整数
                    xmin = int(float(bndbox.find('xmin').text))
                    ymin = int(float(bndbox.find('ymin').text))
                    xmax = int(float(bndbox.find('xmax').text))
                    ymax = int(float(bndbox.find('ymax').text))
                    placed_boxes.append((xmin, ymin, xmax, ymax))
                except (AttributeError, ValueError) as e:
                    print(f"\n警告: 解析背景'{bg_basename}'的XML时,有对象格式不正确,已跳过。错误: {e}")
                    continue
            # ----------------------------------------------------

            patches_to_paste = random.choices(all_small_images, k=NUM_PATCHES_PER_BG)
            
            for small_img_filename in patches_to_paste:
                small_image_path = os.path.join(FOLDER_A, small_img_filename)
                patch_cv_image = cv2.imread(small_image_path, cv2.IMREAD_UNCHANGED)
                if patch_cv_image is None: continue
                
                h, w = patch_cv_image.shape[:2]
                patch_area = w * h

                max_allowed_area = background_area * MAX_AREA_PERCENTAGE
                if patch_area > max_allowed_area:
                    scale_factor = math.sqrt(max_allowed_area / patch_area)
                    new_w, new_h = int(w * scale_factor), int(h * scale_factor)
                    patch_cv_image = cv2.resize(patch_cv_image, (new_w, new_h), interpolation=cv2.INTER_AREA)
                    h, w, patch_area = new_h, new_w, new_w * new_h

                min_required_area = background_area * MIN_AREA_PERCENTAGE
                if patch_area < min_required_area:
                    scale_factor = math.sqrt(min_required_area / patch_area)
                    new_w, new_h = int(w * scale_factor), int(h * scale_factor)
                    if new_w >= bg_width or new_h >= bg_height: continue
                    patch_cv_image = cv2.resize(patch_cv_image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

                if len(patch_cv_image.shape) == 3 and patch_cv_image.shape[2] == 4:
                    patch_cv_image = cv2.cvtColor(patch_cv_image, cv2.COLOR_BGRA2RGBA)
                else:
                    patch_cv_image = cv2.cvtColor(patch_cv_image, cv2.COLOR_BGR2RGB)
                patch_image = Image.fromarray(patch_cv_image)

                if random.choice([True, False]): patch_image = patch_image.transpose(Image.FLIP_LEFT_RIGHT)
                if random.choice([True, False]): patch_image = patch_image.transpose(Image.FLIP_TOP_BOTTOM)
                patch_width, patch_height = patch_image.size

                for _ in range(MAX_PLACEMENT_TRIES):
                    paste_x = random.randint(0, bg_width - patch_width)
                    paste_y = random.randint(0, bg_height - patch_height)
                    candidate_box = (paste_x, paste_y, paste_x + patch_width, paste_y + patch_height)

                    is_valid_placement = True
                    for existing_box in placed_boxes:
                        if is_overlapping(candidate_box, existing_box):
                            is_valid_placement = False
                            break
                    
                    if is_valid_placement:
                        background_image_pil.paste(patch_image, (paste_x, paste_y), patch_image)
                        update_xml_annotation(xml_root, OBJECT_NAME, paste_x, paste_y, candidate_box[2], candidate_box[3])
                        placed_boxes.append(candidate_box)
                        break
                else:
                    print(f"\n警告: 在尝试 {MAX_PLACEMENT_TRIES} 次后,未能为小图 '{small_img_filename}' 找到不重叠的位置。跳过此小图。")

            new_basename = f"augmented_output_{i+1}"
            output_image_path = os.path.join(OUTPUT_IMAGES_FOLDER, new_basename + bg_ext)
            output_xml_path = os.path.join(OUTPUT_ANNOTATIONS_FOLDER, new_basename + ".xml")
            
            background_image_pil.convert("RGB").save(output_image_path)
            tree.write(output_xml_path, encoding='utf-8')

        except Exception as e:
            print(f"\n在生成第 {i+1} 张图片时发生未知错误: {e}")
            continue

    print(f"\n--- 所有任务已完成!---")

if __name__ == '__main__':
    main()

网站公告

今日签到

点亮在社区的每一天
去签到