YOLO v1 Source Code Explained (Part 2): Loss Function and Training Pipeline


Following on from Part 1, we continue our deep dive into the core implementation of YOLO v1, focusing on the design of the loss function and the training pipeline.
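
The snippets below assume the following imports, plus the `YOLOv1` model and the `decode_predictions` / `non_maximum_suppression` utilities from Part 1 (a `VOCDataset` wrapper that emits YOLO-format targets is assumed to be defined elsewhere):


python

import random

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Train on GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")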

The YOLO v1 Loss Function

The loss function is the heart of YOLO v1: a single sum-of-squared-errors objective that jointly optimizes object localization and classification.
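
For reference, the loss from the original paper is:

$$
\begin{aligned}
\mathcal{L} ={} & \lambda_{\text{coord}} \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} \left[ (x_i - \hat{x}_i)^2 + (y_i - \hat{y}_i)^2 \right] \\
&+ \lambda_{\text{coord}} \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} \left[ \left(\sqrt{w_i} - \sqrt{\hat{w}_i}\right)^2 + \left(\sqrt{h_i} - \sqrt{\hat{h}_i}\right)^2 \right] \\
&+ \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} \left( C_i - \hat{C}_i \right)^2 + \lambda_{\text{noobj}} \sum_{i=0}^{S^2} \sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{noobj}} \left( C_i - \hat{C}_i \right)^2 \\
&+ \sum_{i=0}^{S^2} \mathbb{1}_{i}^{\text{obj}} \sum_{c \in \text{classes}} \left( p_i(c) - \hat{p}_i(c) \right)^2
\end{aligned}
$$

Here $\mathbb{1}_{ij}^{\text{obj}}$ is 1 when box $j$ in cell $i$ is "responsible" for an object (the predicted box with the highest IoU against the ground truth) and 0 otherwise, and $\mathbb{1}_{ij}^{\text{noobj}}$ is its complement. The implementation below mirrors this term by term: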


python

class YOLOLoss(nn.Module):
    def __init__(self, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
        super(YOLOLoss, self).__init__()
        self.S = S  # grid size
        self.B = B  # number of bounding boxes predicted per cell
        self.C = C  # number of classes
        self.lambda_coord = lambda_coord  # weight for the coordinate loss
        self.lambda_noobj = lambda_noobj  # weight for the confidence loss of empty cells
        self.mse = nn.MSELoss(reduction='sum')
    
    def forward(self, predictions, targets):
        """
        Compute the YOLO loss.
        
        Args:
            predictions: model output [batch_size, S, S, B*5+C]
            targets: ground-truth labels [batch_size, S, S, 5+C],
                     laid out as [obj_flag, x, y, w, h, class_probs]
        
        Returns:
            Scalar loss value.
        """
        batch_size = predictions.size(0)
        
        # Split predictions and targets into their components
        # Predicted bounding box 1: [batch_size, S, S, 5]
        pred_box1 = predictions[..., :5]
        # Predicted bounding box 2: [batch_size, S, S, 5]
        pred_box2 = predictions[..., 5:10]
        # Predicted class probabilities: [batch_size, S, S, C]
        pred_classes = predictions[..., 10:]
        
        # Ground-truth bounding box: [batch_size, S, S, 4]
        target_box = targets[..., 1:5]
        # Ground-truth class probabilities: [batch_size, S, S, C]
        target_classes = targets[..., 5:]
        # Whether an object is present in each cell: [batch_size, S, S, 1]
        obj_mask = targets[..., 0].unsqueeze(3)
        noobj_mask = 1 - obj_mask
        
        # IoU of each predicted box with the ground truth: [batch_size, S, S, 1]
        iou_box1 = calculate_iou(pred_box1[..., :4], target_box)
        iou_box2 = calculate_iou(pred_box2[..., :4], target_box)
        
        # Pick the box with the higher IoU; both results are [batch_size, S, S, 1]
        iou_maxes, best_box = torch.max(torch.stack([iou_box1, iou_box2], dim=0), dim=0)
        
        # Mask for the responsible box (0 = first box, 1 = second box)
        best_box_mask = best_box.float()
        
        # Select the predictions of the responsible box
        responsible_box = (1 - best_box_mask) * pred_box1 + best_box_mask * pred_box2
        
        # 1. Coordinate loss (only for cells that contain an object)
        # 1.1 Center (x, y) loss
        center_loss = self.mse(
            responsible_box[..., :2] * obj_mask,
            target_box[..., :2] * obj_mask
        )
        
        # 1.2 Width/height (w, h) loss — square roots shrink the gap between
        #     large and small boxes
        width_height_loss = self.mse(
            torch.sign(responsible_box[..., 2:4]) * torch.sqrt(torch.abs(responsible_box[..., 2:4]) + 1e-6) * obj_mask,
            torch.sqrt(target_box[..., 2:4] + 1e-6) * obj_mask
        )
        
        # 2. Confidence loss
        # 2.1 Cells that contain an object: the confidence target is the IoU,
        #     detached so it is treated as a constant during backprop
        confidence_target = obj_mask * iou_maxes.detach()
        obj_confidence_loss = self.mse(
            responsible_box[..., 4:5] * obj_mask,
            confidence_target
        )
        
        # 2.2 Cells without an object: push both boxes' confidence toward zero
        noobj_confidence_loss = self.mse(
            pred_box1[..., 4:5] * noobj_mask,
            torch.zeros_like(pred_box1[..., 4:5])
        ) + self.mse(
            pred_box2[..., 4:5] * noobj_mask,
            torch.zeros_like(pred_box2[..., 4:5])
        )
        
        # 3. Class probability loss
        class_loss = self.mse(
            pred_classes * obj_mask,
            target_classes * obj_mask
        )
        
        # Total loss = coordinate loss + confidence loss + class loss
        total_loss = (
            self.lambda_coord * (center_loss + width_height_loss) +
            obj_confidence_loss +
            self.lambda_noobj * noobj_confidence_loss +
            class_loss
        )
        
        return total_loss / batch_size
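
A quick smoke test with synthetic tensors (illustrative values only) verifies the expected shapes:


python

loss_fn = YOLOLoss(S=7, B=2, C=20)
preds = torch.rand(2, 7, 7, 30)     # B*5 + C = 2*5 + 20 = 30 channels
targets = torch.zeros(2, 7, 7, 25)  # 5 + C = 25 channels
targets[:, 3, 3, 0] = 1             # one cell per image contains an object
targets[:, 3, 3, 1:5] = 0.5         # arbitrary (x, y, w, h)
targets[:, 3, 3, 5] = 1             # one-hot class
print(loss_fn(preds, targets))      # a scalar tensor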

The IoU Function

IoU (Intersection over Union) measures how much two bounding boxes overlap. Inside the loss function we need the IoU between each predicted box and the ground-truth box:


python

def calculate_iou(pred_boxes, target_boxes):
    """
    Compute the IoU between predicted and target boxes.
    
    Args:
        pred_boxes: predicted box coordinates [batch_size, S, S, 4] (x, y, w, h)
        target_boxes: target box coordinates [batch_size, S, S, 4] (x, y, w, h)
        
    Returns:
        iou: [batch_size, S, S, 1]
    """
    # Convert (x, y, w, h) to (x1, y1, x2, y2)
    pred_x1 = pred_boxes[..., 0:1] - pred_boxes[..., 2:3] / 2
    pred_y1 = pred_boxes[..., 1:2] - pred_boxes[..., 3:4] / 2
    pred_x2 = pred_boxes[..., 0:1] + pred_boxes[..., 2:3] / 2
    pred_y2 = pred_boxes[..., 1:2] + pred_boxes[..., 3:4] / 2
    
    target_x1 = target_boxes[..., 0:1] - target_boxes[..., 2:3] / 2
    target_y1 = target_boxes[..., 1:2] - target_boxes[..., 3:4] / 2
    target_x2 = target_boxes[..., 0:1] + target_boxes[..., 2:3] / 2
    target_y2 = target_boxes[..., 1:2] + target_boxes[..., 3:4] / 2
    
    # Intersection rectangle
    x1 = torch.max(pred_x1, target_x1)
    y1 = torch.max(pred_y1, target_y1)
    x2 = torch.min(pred_x2, target_x2)
    y2 = torch.min(pred_y2, target_y2)
    
    # Intersection area; clamp so non-overlapping boxes contribute zero
    intersection = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)
    
    # Areas of the individual boxes
    pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
    target_area = (target_x2 - target_x1) * (target_y2 - target_y1)
    
    # Union area; a small epsilon avoids division by zero
    union = pred_area + target_area - intersection + 1e-6
    
    # IoU
    iou = intersection / union
    
    return iou
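
A small sanity check (hand-picked values):


python

# Identical boxes → IoU ≈ 1
box = torch.tensor([0.5, 0.5, 0.2, 0.2]).view(1, 1, 1, 4)
print(calculate_iou(box, box))      # ≈ 1.0

# Shift the center by half a box width: intersection 0.1 * 0.2 = 0.02,
# union 0.04 + 0.04 - 0.02 = 0.06, so IoU ≈ 0.333
shifted = torch.tensor([0.6, 0.5, 0.2, 0.2]).view(1, 1, 1, 4)
print(calculate_iou(box, shifted))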

Data Preprocessing and Label Conversion

To train a YOLO model, the raw annotations must first be converted into YOLO-format labels:


python

def convert_to_yolo_format(boxes, labels, image_size, S=7, C=20):
    """
    Convert bounding boxes and labels into YOLO-format targets.
    
    Args:
        boxes: raw bounding boxes [N, 4] (x1, y1, x2, y2), absolute pixel coordinates
        labels: class labels [N]
        image_size: image size as (height, width)
        S: grid size
        C: number of classes
        
    Returns:
        target: YOLO-format target tensor [S, S, 5+C]
    """
    target = torch.zeros((S, S, 5 + C))
    
    # Convert absolute coordinates to relative ones
    boxes_rel = boxes.clone()
    boxes_rel[:, [0, 2]] /= image_size[1]  # divide x coordinates by the width
    boxes_rel[:, [1, 3]] /= image_size[0]  # divide y coordinates by the height
    
    # Convert (x1, y1, x2, y2) to (x, y, w, h)
    boxes_xywh = torch.zeros_like(boxes_rel)
    boxes_xywh[:, 0] = (boxes_rel[:, 0] + boxes_rel[:, 2]) / 2  # center x
    boxes_xywh[:, 1] = (boxes_rel[:, 1] + boxes_rel[:, 3]) / 2  # center y
    boxes_xywh[:, 2] = boxes_rel[:, 2] - boxes_rel[:, 0]  # width
    boxes_xywh[:, 3] = boxes_rel[:, 3] - boxes_rel[:, 1]  # height
    
    # For every bounding box
    for i in range(len(boxes)):
        # Find the grid cell that contains the box center
        grid_x = int(S * boxes_xywh[i, 0])
        grid_y = int(S * boxes_xywh[i, 1])
        
        # Keep the cell indices inside the grid
        grid_x = min(grid_x, S - 1)
        grid_y = min(grid_y, S - 1)
        
        # Offset of the center relative to the cell
        x_offset = boxes_xywh[i, 0] * S - grid_x
        y_offset = boxes_xywh[i, 1] * S - grid_y
        
        # Objectness flag
        target[grid_y, grid_x, 0] = 1
        
        # Box coordinates relative to the cell
        target[grid_y, grid_x, 1] = x_offset
        target[grid_y, grid_x, 2] = y_offset
        target[grid_y, grid_x, 3] = boxes_xywh[i, 2]  # width
        target[grid_y, grid_x, 4] = boxes_xywh[i, 3]  # height
        
        # One-hot class probabilities
        target[grid_y, grid_x, 5 + labels[i]] = 1
    
    return target
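
For example (hypothetical numbers), an object centered in a 400×400 image falls into grid cell (3, 3) with a 0.5 offset in each direction:


python

boxes = torch.tensor([[100.0, 100.0, 300.0, 300.0]])  # (x1, y1, x2, y2) in pixels
labels = torch.tensor([7])
target = convert_to_yolo_format(boxes, labels, image_size=(400, 400))
# Normalized center is (0.5, 0.5); 7 * 0.5 = 3.5 → cell (3, 3), offset 0.5
print(target[3, 3, :5])     # tensor([1.0000, 0.5000, 0.5000, 0.5000, 0.5000])
print(target[3, 3, 5 + 7])  # tensor(1.) — the one-hot class entry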

Data Augmentation

To improve generalization, YOLO v1 relies on data augmentation. The paper uses random scaling and translations of up to 20% of the original image size, plus exposure and saturation jitter of up to a factor of 1.5 in HSV space; the implementation below approximates this:


python

def data_augmentation(image, boxes):
    """
    Apply data augmentation to an image and its bounding boxes.
    
    Args:
        image: input image as a numpy array [H, W, 3] (uint8, BGR as loaded by cv2)
        boxes: bounding boxes [N, 4] (x1, y1, x2, y2) in absolute pixel coordinates
        
    Returns:
        augmented_image: the augmented image, resized to 448x448
        augmented_boxes: the augmented boxes, still expressed in the original
                         image's coordinate frame
    """
    augmented_image = image.copy()
    augmented_boxes = boxes.clone()
    
    # 1. Random photometric jitter. The torchvision functional transforms expect
    #    PIL Images or tensors, so for a numpy/cv2 pipeline we jitter brightness
    #    and contrast directly, and saturation/hue in HSV space.
    img = augmented_image.astype(np.float32)
    if random.random() < 0.5:  # brightness
        img *= random.uniform(0.8, 1.2)
    if random.random() < 0.5:  # contrast
        mean = img.mean()
        img = (img - mean) * random.uniform(0.8, 1.2) + mean
    img = np.clip(img, 0, 255).astype(np.uint8)
    if random.random() < 0.5:  # saturation and hue
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
        hsv[..., 1] *= random.uniform(0.8, 1.2)                      # saturation
        hsv[..., 0] = (hsv[..., 0] + random.uniform(-18, 18)) % 180  # hue; ±18 ≈ ±0.1 of the hue circle (OpenCV range 0-179)
        img = cv2.cvtColor(np.clip(hsv, 0, 255).astype(np.uint8), cv2.COLOR_HSV2BGR)
    augmented_image = img
    
    # 2. Random horizontal flip
    if random.random() < 0.5:
        augmented_image = augmented_image[:, ::-1].copy()
        width = augmented_image.shape[1]
        augmented_boxes[:, 0] = width - augmented_boxes[:, 0] - 1  # mirror x1
        augmented_boxes[:, 2] = width - augmented_boxes[:, 2] - 1  # mirror x2
        # Swap x1 and x2 so that x1 < x2 still holds
        augmented_boxes[:, [0, 2]] = augmented_boxes[:, [2, 0]]
    
    # 3. Random crop and rescale
    if random.random() < 0.5:
        height, width = augmented_image.shape[:2]
        scale = random.uniform(0.8, 1.0)
        new_height = int(height * scale)
        new_width = int(width * scale)
        
        # Pick a random crop window
        x_offset = random.randint(0, width - new_width)
        y_offset = random.randint(0, height - new_height)
        
        # Crop the image
        augmented_image = augmented_image[y_offset:y_offset+new_height, x_offset:x_offset+new_width]
        
        # Map the boxes into the crop and scale them back to the original
        # coordinate frame (the final resize below makes the crop fill the image)
        augmented_boxes[:, 0] = (augmented_boxes[:, 0] - x_offset) * width / new_width
        augmented_boxes[:, 1] = (augmented_boxes[:, 1] - y_offset) * height / new_height
        augmented_boxes[:, 2] = (augmented_boxes[:, 2] - x_offset) * width / new_width
        augmented_boxes[:, 3] = (augmented_boxes[:, 3] - y_offset) * height / new_height
        
        # Clamp the boxes to the valid range
        augmented_boxes[:, 0] = torch.clamp(augmented_boxes[:, 0], min=0, max=width-1)
        augmented_boxes[:, 1] = torch.clamp(augmented_boxes[:, 1], min=0, max=height-1)
        augmented_boxes[:, 2] = torch.clamp(augmented_boxes[:, 2], min=0, max=width-1)
        augmented_boxes[:, 3] = torch.clamp(augmented_boxes[:, 3], min=0, max=height-1)
    
    # 4. Finally resize to the network's input size
    augmented_image = cv2.resize(augmented_image, (448, 448))
    
    return augmented_image, augmented_boxes
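
Typical usage with an OpenCV-loaded image (the file path and box values are placeholders):


python

image = cv2.imread("example.jpg")                   # [H, W, 3], uint8, BGR
boxes = torch.tensor([[48.0, 32.0, 210.0, 180.0]])  # (x1, y1, x2, y2)
aug_image, aug_boxes = data_augmentation(image, boxes)
# aug_image is 448x448; aug_boxes remain in the original image's coordinate
# frame, so pass the original (height, width) to convert_to_yolo_format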

Training Pipeline

Here is the full YOLO v1 training loop:


python

def train_yolo(model, train_loader, val_loader, num_epochs=135, lr=0.0001):
    """
    Train a YOLO model.
    
    Args:
        model: the YOLO model
        train_loader: training data loader
        val_loader: validation data loader
        num_epochs: number of training epochs
        lr: initial learning rate
    """
    # Loss function and optimizer
    criterion = YOLOLoss(S=7, B=2, C=20)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    
    # Learning-rate scheduler: decay by 10x at epochs 75 and 105,
    # mirroring the schedule from the paper
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, 
        milestones=[75, 105], 
        gamma=0.1
    )
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        
        for batch_idx, (images, targets) in enumerate(train_loader):
            # Move the batch to the training device
            images = images.to(device)   # [batch_size, 3, 448, 448]
            targets = targets.to(device) # [batch_size, S, S, 5+C]
            
            # Forward pass
            predictions = model(images)
            
            # Compute the loss
            loss = criterion(predictions, targets)
            
            # Backward pass and optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Accumulate the loss
            epoch_loss += loss.item()
            
            # Print training progress
            if batch_idx % 50 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}")
        
        # Step the learning-rate scheduler
        scheduler.step()
        
        # Average loss for the epoch
        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
        
        # Validate every 5 epochs
        if (epoch + 1) % 5 == 0:
            validate_yolo(model, val_loader, criterion)
        
        # Save a checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), f"yolo_epoch_{epoch+1}.pth")
    
    # Save the final model
    torch.save(model.state_dict(), "yolo_final.pth")
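
One caveat: the implementation above uses Adam for simplicity, whereas the paper trains with SGD (momentum 0.9, weight decay 0.0005). To stay closer to the original setup, a sketch of the paper-style optimizer would be:


python

# Paper-style optimizer (a sketch). The paper additionally warms the learning
# rate up from 1e-3 to 1e-2 over the first epochs before the decays at 75/105.
optimizer = torch.optim.SGD(
    model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4
)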

Validation

During training, we periodically evaluate the model:


python

def validate_yolo(model, val_loader, criterion):
    """
    Evaluate a YOLO model.
    
    Args:
        model: the YOLO model
        val_loader: validation data loader
        criterion: loss function
    """
    model.eval()
    val_loss = 0
    all_detections = []
    all_ground_truths = []
    
    with torch.no_grad():
        for images, targets in val_loader:
            images = images.to(device)
            targets = targets.to(device)
            
            # Forward pass
            predictions = model(images)
            
            # Compute the loss
            loss = criterion(predictions, targets)
            val_loss += loss.item()
            
            # Decode the predictions
            batch_size = images.size(0)
            for i in range(batch_size):
                # Per-image predictions
                pred = predictions[i]
                boxes, class_ids, scores = decode_predictions(
                    pred.unsqueeze(0),  # add a batch dimension
                    S=7, 
                    B=2, 
                    C=20, 
                    confidence_threshold=0.1
                )
                
                # Apply non-maximum suppression
                keep_indices = non_maximum_suppression(boxes, scores, iou_threshold=0.5)
                
                if len(keep_indices) > 0:
                    boxes = boxes[keep_indices]
                    class_ids = class_ids[keep_indices]
                    scores = scores[keep_indices]
                    
                    # Store the detections
                    detections = []
                    for j in range(len(boxes)):
                        detections.append({
                            'bbox': boxes[j].cpu().numpy(),
                            'class_id': class_ids[j].item(),
                            'confidence': scores[j].item()
                        })
                    all_detections.append(detections)
                else:
                    all_detections.append([])
                
                # Collect the ground-truth boxes
                target = targets[i]
                ground_truths = []
                for cy in range(7):
                    for cx in range(7):
                        if target[cy, cx, 0] > 0:  # an object is present
                            # Decode the ground-truth box
                            x = (target[cy, cx, 1] + cx) / 7
                            y = (target[cy, cx, 2] + cy) / 7
                            w = target[cy, cx, 3]
                            h = target[cy, cx, 4]
                            
                            # Convert to (x1, y1, x2, y2)
                            x1 = max(0, x - w/2)
                            y1 = max(0, y - h/2)
                            x2 = min(1, x + w/2)
                            y2 = min(1, y + h/2)
                            
                            # Recover the class
                            class_probs = target[cy, cx, 5:]
                            class_id = torch.argmax(class_probs).item()
                            
                            ground_truths.append({
                                'bbox': [x1, y1, x2, y2],
                                'class_id': class_id
                            })
                all_ground_truths.append(ground_truths)
    
    # Average validation loss
    avg_val_loss = val_loss / len(val_loader)
    print(f"Validation Loss: {avg_val_loss:.4f}")
    
    # Compute mAP
    mAP = calculate_mAP(all_detections, all_ground_truths)
    print(f"mAP: {mAP:.4f}")
    
    return avg_val_loss, mAP

Computing mAP

To evaluate detection quality, we compute mAP (mean Average Precision).
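
The implementation below uses PASCAL VOC's classic 11-point interpolation: for each class, precision is sampled at 11 equally spaced recall thresholds and averaged, and mAP is the mean of the per-class APs:

$$\text{AP} = \frac{1}{11} \sum_{t \in \{0,\, 0.1,\, \ldots,\, 1\}} \max_{\tilde{r} \ge t} p(\tilde{r}), \qquad \text{mAP} = \frac{1}{|\mathcal{C}|} \sum_{c \in \mathcal{C}} \text{AP}_c$$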


python

def calculate_mAP(all_detections, all_ground_truths, iou_threshold=0.5, num_classes=20):
    """
    Compute mAP (mean Average Precision).
    
    Args:
        all_detections: per-image detections
        all_ground_truths: per-image ground-truth boxes
        iou_threshold: IoU threshold for a match
        num_classes: number of classes
    
    Returns:
        mAP: mean Average Precision
    """
    # Group detections and ground truths by class
    class_predictions = [[] for _ in range(num_classes)]
    class_ground_truths = [[] for _ in range(num_classes)]
    
    # Iterate over all images
    for img_idx in range(len(all_detections)):
        # Detections and ground truths for the current image
        detections = all_detections[img_idx]
        ground_truths = all_ground_truths[img_idx]
        
        # Collect the image's detections
        for detection in detections:
            class_id = detection['class_id']
            confidence = detection['confidence']
            bbox = detection['bbox']
            
            class_predictions[class_id].append({
                'confidence': confidence,
                'bbox': bbox,
                'img_idx': img_idx,
                'matched': False  # whether this detection has been matched
            })
        
        # Collect the image's ground-truth boxes
        for gt in ground_truths:
            class_id = gt['class_id']
            bbox = gt['bbox']
            
            class_ground_truths[class_id].append({
                'bbox': bbox,
                'img_idx': img_idx,
                'matched': False  # whether this ground truth has been claimed
            })
    
    # Compute AP for each class
    APs = []
    for class_id in range(num_classes):
        # Predictions and ground truths for this class
        predictions = class_predictions[class_id]
        ground_truths = class_ground_truths[class_id]
        
        # Skip classes with no ground-truth boxes
        if len(ground_truths) == 0:
            continue
        
        # Sort predictions by confidence, highest first
        predictions.sort(key=lambda x: x['confidence'], reverse=True)
        
        # Build the precision/recall curve
        precisions = []
        recalls = []
        tp = 0  # true positives
        fp = 0  # false positives
        total_gt = len(ground_truths)
        
        for pred_idx, pred in enumerate(predictions):
            img_idx = pred['img_idx']
            pred_bbox = pred['bbox']
            
            # Look for an unclaimed matching ground truth in the same image
            matched_gt = False
            for gt in ground_truths:
                if gt['img_idx'] == img_idx and not gt['matched']:
                    gt_bbox = gt['bbox']
                    
                    # Intersection rectangle
                    x1 = max(pred_bbox[0], gt_bbox[0])
                    y1 = max(pred_bbox[1], gt_bbox[1])
                    x2 = min(pred_bbox[2], gt_bbox[2])
                    y2 = min(pred_bbox[3], gt_bbox[3])
                    
                    # Intersection area
                    intersection = max(0, x2 - x1) * max(0, y2 - y1)
                    
                    # Individual areas
                    pred_area = (pred_bbox[2] - pred_bbox[0]) * (pred_bbox[3] - pred_bbox[1])
                    gt_area = (gt_bbox[2] - gt_bbox[0]) * (gt_bbox[3] - gt_bbox[1])
                    
                    # Union area
                    union = pred_area + gt_area - intersection
                    
                    # IoU
                    iou = intersection / union
                    
                    if iou >= iou_threshold:
                        matched_gt = True
                        gt['matched'] = True
                        break
            
            # Update the TP and FP counts
            if matched_gt:
                tp += 1
            else:
                fp += 1
            
            # Precision and recall after this detection
            precision = tp / (tp + fp)
            recall = tp / total_gt
            
            precisions.append(precision)
            recalls.append(recall)
        
        # AP via 11-point interpolation
        AP = 0
        for t in np.arange(0, 1.1, 0.1):
            if np.sum(np.array(recalls) >= t) == 0:
                p = 0
            else:
                p = np.max(np.array(precisions)[np.array(recalls) >= t])
            AP += p / 11
        
        APs.append(AP)
    
    # Mean over classes (0 if no class had any ground truth)
    mAP = np.mean(APs) if APs else 0.0
    
    return mAP

A Complete Training Example

Putting the pieces together, here is a complete training script:


python

if __name__ == "__main__":
    # Build the model
    model = YOLOv1(S=7, B=2, C=20)
    model = model.to(device)
    
    # Datasets, transforms, and data loaders
    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Using the VOC dataset as an example (VOCDataset is assumed to emit YOLO-format targets)
    train_dataset = VOCDataset(
        root="path/to/VOC2007",
        year="2007",
        image_set="train",
        transform=transform,
        S=7, B=2, C=20
    )
    
    val_dataset = VOCDataset(
        root="path/to/VOC2007",
        year="2007",
        image_set="val",
        transform=transform,
        S=7, B=2, C=20
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False
    )
    
    # Train the model
    train_yolo(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=135,
        lr=0.0001
    )

Summary

This article walked through YOLO v1's loss function design and training pipeline in detail: the multi-part MSE loss and its responsible-box assignment, IoU computation, label conversion, data augmentation, and the training, validation, and mAP-evaluation code.

