loss.py
ultralytics\models\utils\loss.py
目录
3.class RTDETRDetectionLoss(DETRLoss):
1.所需的库和模块
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.utils.metrics import bbox_iou
from .ops import HungarianMatcher
2.class DETRLoss(nn.Module):
# 这段代码定义了一个名为 DETRLoss 的类,用于计算 DETR(DEtection TRansformer)目标检测模型的损失函数。它包含了分类损失、边界框损失、GIoU 损失以及可选的辅助损失。
# 定义了一个继承自 PyTorch 的 nn.Module 的类 DETRLoss ,表示这是一个可训练的模块。
class DETRLoss(nn.Module):
# DETR (DEtection TRansformer) 损失类。此类计算并返回 DETR 对象检测模型的不同损失组件。它计算分类损失、边界框损失、GIoU 损失以及可选的辅助损失。
"""
DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the
DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary
losses.
Attributes:
nc (int): The number of classes.
loss_gain (dict): Coefficients for different loss components.
aux_loss (bool): Whether to compute auxiliary losses.
use_fl (bool): Use FocalLoss or not.
use_vfl (bool): Use VarifocalLoss or not.
use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True.
matcher (HungarianMatcher): Object to compute matching cost and indices.
fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None.
vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None.
device (torch.device): Device on which tensors are stored.
"""
# 这段代码是 DETRLoss 类的初始化方法 __init__ ,用于设置类的属性和初始化相关组件。
# 定义了 DETRLoss 类的初始化方法,接收以下参数 :
# 1.nc :类别数量,默认为 80。
# 2.loss_gain :损失权重的字典,默认为 None 。
# 3.aux_loss :是否计算辅助损失,默认为 True 。
# 4.use_fl :是否使用焦点损失(Focal Loss),默认为 True 。
# 5.use_vfl :是否使用变焦损失(Varifocal Loss),默认为 False 。
# 6.use_uni_match :是否使用固定层为辅助分支分配标签,默认为 False 。
# 7.uni_match_ind :如果 use_uni_match 为 True ,则指定固定层的索引,默认为 0。
def __init__(
self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
):
# 使用可自定义的组件和增益初始化 DETR 损失函数。
# 如果未提供,则使用默认的 loss_gain。使用预设的成本增益初始化 HungarianMatcher。支持辅助损失和各种损失类型。
"""
Initialize DETR loss function with customizable components and gains.
Uses default loss_gain if not provided. Initializes HungarianMatcher with
preset cost gains. Supports auxiliary losses and various loss types.
Args:
nc (int): Number of classes.
loss_gain (dict): Coefficients for different loss components.
aux_loss (bool): Use auxiliary losses from each decoder layer.
use_fl (bool): Use FocalLoss.
use_vfl (bool): Use VarifocalLoss.
use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
uni_match_ind (int): Index of fixed layer for uni_match.
"""
# 调用父类 nn.Module 的初始化方法,确保继承自 PyTorch 的模块初始化逻辑。
super().__init__()
# 如果未提供 loss_gain 参数,则使用默认的损失权重字典。
if loss_gain is None:
# 这些权重用于 平衡不同损失组件的贡献 。
# "class" :分类损失权重为 1。
# "bbox" :边界框损失权重为 5。
# "giou" :GIoU 损失权重为 2。
# "no_object" :无目标分类损失权重为 0.1。
# "mask" :掩码损失权重为 1。
# "dice" :Dice 损失权重为 1。
loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
# 将 类别数量 nc 赋值给类的属性 self.nc ,用于后续计算分类损失。
self.nc = nc
# 初始化一个 HungarianMatcher 对象,用于 计算预测框和目标框之间的匹配成本 和 匹配索引 。 cost_gain 是匹配成本的权重,具体为 :
# "class" :分类成本权重为 2。
# "bbox" :边界框成本权重为 5。
# "giou" :GIoU 成本权重为 2。
# class HungarianMatcher(nn.Module):
# -> 用于实现基于匈牙利算法的目标检测匹配机制。它主要用于计算预测框和真实框之间的匹配关系,以便在目标检测任务中优化模型的训练过程。
# -> def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
# 将 损失权重字典 loss_gain 赋值给类的属性 self.loss_gain ,用于后续计算不同损失组件时的权重调整。
self.loss_gain = loss_gain
# 将 是否计算辅助损失的标志 aux_loss 赋值给类的属性 self.aux_loss 。辅助损失通常用于 Transformer 的解码器中间层,以增强训练过程。
self.aux_loss = aux_loss
# 根据 use_fl 的值决定是否初始化一个 FocalLoss 对象。如果 use_fl 为 True ,则初始化 FocalLoss ;否则,将其设置为 None 。焦点损失用于解决 分类任务中类别不平衡的问题 。
# class FocalLoss(nn.Module):
# -> 用于实现焦点损失(Focal Loss)功能,它是一种改进版的交叉熵损失函数,主要用于解决类别不平衡问题,尤其适用于目标检测和分类任务。
# -> def __init__(self):
self.fl = FocalLoss() if use_fl else None
# 根据 use_vfl 的值决定是否初始化一个 VarifocalLoss 对象。如果 use_vfl 为 True ,则初始化 VarifocalLoss ;否则,将其设置为 None 。变焦损失是焦点损失的变体,适用于某些特定任务。
# class VarifocalLoss(nn.Module):
# -> 用于计算一种变焦损失(Varifocal Loss),通常用于目标检测或分类任务中,以优化分类得分的预测。
# -> def __init__(self):
self.vfl = VarifocalLoss() if use_vfl else None
# 将 是否使用固定层分配标签 的标志 use_uni_match 赋值给类的属性 self.use_uni_match 。如果为 True ,则在辅助分支中使用固定层的索引进行标签分配。
self.use_uni_match = use_uni_match
# 将 固定层的索引 uni_match_ind 赋值给类的属性 self.uni_match_ind 。只有当 use_uni_match 为 True 时,该属性才会被使用。
self.uni_match_ind = uni_match_ind
# 初始化设备属性 self.device 为 None 。该属性稍后会在 forward 方法中被设置为实际的设备(如 CPU 或 GPU),用于确保张量操作在正确的设备上执行。
self.device = None
# 这段代码初始化了 DETRLoss 类的核心属性和组件,包括。类别数量和损失权重。匹配器( HungarianMatcher )用于计算预测框和目标框之间的匹配关系。是否使用辅助损失、焦点损失和变焦损失。是否使用固定层分配标签的标志及其索引。初始化设备属性,稍后用于确保张量操作在正确的设备上执行。这些初始化操作为后续的损失计算提供了必要的配置和工具。
# 这段代码定义了 DETRLoss 类中的 _get_loss_class 方法,用于计算分类损失(Classification Loss)。
# 定义了一个私有方法 _get_loss_class ,用于计算分类损失。它接收以下参数 :
# 1.pred_scores :预测的类别分数,形状为 [batch_size, num_queries, num_classes] 。
# 1.targets :目标类别标签,形状为 [batch_size, num_queries] 。
# 2.gt_scores :目标分数(通常用于计算目标框的置信度),形状为 [batch_size, num_queries] 。
# 3.num_gts :目标框的数量。
# 4.postfix :损失名称的后缀,默认为空字符串。
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
# 根据预测、目标值和基本事实分数计算分类损失。
"""Computes the classification loss based on predictions, target values, and ground truth scores."""
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
# 定义了 分类损失的名称 ,包含后缀(如果有),用于 区分主损失和辅助损失 。
name_class = f"loss_class{postfix}"
# 提取 批次大小 ( bs )和 查询数量 ( nq ),即 pred_scores 的前两个维度。
bs, nq = pred_scores.shape[:2]
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
# 将 目标类别标签 targets 转换为 one-hot 编码。
# 创建一个形状为 [batch_size, num_queries, num_classes + 1] 的零张量,用于存储 one-hot 编码。
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
# 使用 scatter_ 方法将目标类别标签的索引位置设置为 1。
one_hot.scatter_(2, targets.unsqueeze(-1), 1)
# 去掉最后一个维度(通常用于背景类别),最终形状为 [batch_size, num_queries, num_classes] 。
one_hot = one_hot[..., :-1]
# 将 目标分数 gt_scores 与 one-hot 编码相乘,生成 目标分数张量 。这一步确保 目标分数只与对应的目标类别相关联 。
gt_scores = gt_scores.view(bs, nq, 1) * one_hot
# 这段代码是 _get_loss_class 方法的核心部分,用于根据配置动态选择分类损失的计算方式。
# 判断是否使用焦点损失( FocalLoss )。如果 self.fl 不为 None ,则表示焦点损失已经被初始化并可用。
if self.fl:
# 如果 目标框数量 num_gts 大于零,并且变焦损失( VarifocalLoss )已经被初始化( self.vfl 不为 None ),则使用变焦损失计算分类损失。
if num_gts and self.vfl:
# 调用 变焦损失函数 VarifocalLoss 计算 分类损失 。变焦损失是焦点损失的改进版本,适用于某些特定任务。它需要以下输入 :
# pred_scores :预测的类别分数。
# gt_scores :目标分数(与 one-hot 编码的目标类别相乘后的结果)。
# one_hot :目标类别的 one-hot 编码。
loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
# 如果未启用变焦损失( self.vfl 为 None )。
else:
# 则使用 焦点损失函数 FocalLoss 计算分类损失。焦点损失需要以下输入 :
# pred_scores :预测的类别分数。
# one_hot.float() :目标类别的 one-hot 编码(转换为浮点类型)。
loss_cls = self.fl(pred_scores, one_hot.float())
# 对计算得到的 分类损失 进行 归一化处理 。
# max(num_gts, 1) :确保目标框数量不为零,避免除以零的错误。
# nq :查询数量( num_queries ),用于对损失进行缩放,使其与查询数量无关。
loss_cls /= max(num_gts, 1) / nq
# 如果未启用焦点损失( self.fl 为 None )。
else:
# 则使用标准的二值交叉熵损失( BCEWithLogitsLoss )计算分类损失。
# reduction="none" :不对损失进行归约,保留每个查询的损失值。
# mean(1) :对每个查询的损失取平均。
# sum() :对所有查询的损失求和。
loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
# 段代码的核心逻辑是根据配置动态选择分类损失的计算方式。焦点损失( FocalLoss ):如果启用变焦损失( VarifocalLoss ),则使用变焦损失计算。否则,使用焦点损失计算。计算后的损失通过目标框数量和查询数量进行归一化,以确保损失值的稳定性。二值交叉熵损失( BCEWithLogitsLoss ):如果未启用焦点损失,则使用标准的二值交叉熵损失计算分类损失。归一化处理:使用 max(num_gts, 1) / nq 对损失进行归一化,避免目标框数量为零时的数值问题,并确保损失与查询数量无关。
# 返回一个字典,包含 分类损失的名称 和 值 。损失值乘以对应的权重 self.loss_gain["class"] ,以 平衡不同损失组件的贡献 。
return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
# 这段代码实现了分类损失的计算,支持以下功能。目标类别标签的 one-hot 编码:将目标类别标签转换为 one-hot 编码,以便与预测分数进行比较。焦点损失和变焦损失的支持:根据配置动态选择焦点损失或变焦损失,或者使用标准的二值交叉熵损失。损失归一化:通过目标框数量进行归一化,避免目标框数量为零时的数值问题。权重调整:将分类损失乘以对应的权重,以平衡不同损失组件的贡献。该方法是 DETRLoss 类的核心部分之一,用于计算分类任务的损失值。
# 这段代码定义了 DETRLoss 类中的 _get_loss_bbox 方法,用于计算边界框损失(Bounding Box Loss)和 GIoU 损失(Generalized Intersection over Union Loss)。
# 定义了一个私有方法 _get_loss_bbox ,用于计算边界框损失和 GIoU 损失。它接收以下参数 :
# 1.pred_bboxes :预测的边界框,形状为 [batch_size, num_queries, 4] 。
# 2.gt_bboxes :目标边界框,形状为 [num_gts, 4] 。
# 3.postfix :损失名称的后缀,默认为空字符串。
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
# 计算预测和地面真实边界框的边界框和 GIoU 损失。
"""Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
# 定义了 边界框损失 和 GIoU 损失 的 名称 ,包含后缀(如果有),用于区分主损失和辅助损失。
name_bbox = f"loss_bbox{postfix}"
name_giou = f"loss_giou{postfix}"
# 初始化一个字典来存储损失。
loss = {}
# 如果目标边界框的数量为零,则将边界框损失和 GIoU 损失设置为 0,并返回。
if len(gt_bboxes) == 0:
loss[name_bbox] = torch.tensor(0.0, device=self.device)
loss[name_giou] = torch.tensor(0.0, device=self.device)
return loss
# 计算 边界框损失 。
# 使用 PyTorch 的 F.l1_loss 计算预测边界框和目标边界框之间的 L1 损失。 将损失乘以对应的权重 self.loss_gain["bbox"] 。 将损失除以目标边界框的数量 len(gt_bboxes) ,进行归一化。
loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
# 计算 GIoU 损失 。
# 使用 bbox_iou 函数计算预测边界框和目标边界框之间的 GIoU。 GIoU 的值范围在 -1 到 1 之间,因此使用 1.0 - bbox_iou 计算损失。
# def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
# -> 用于计算两个边界框(bounding boxes)之间的交并比(IoU)以及其变体(GIoU、DIoU、CIoU)。
# -> return iou - (rho2 / c2 + v * alpha) # CIoU
# -> return iou - rho2 / c2 # DIoU
# -> eturn iou - (c_area - union) / c_area # GIoU
# -> return iou # IoU
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
# 对 GIoU 损失进行归一化,将损失除以目标边界框的数量 len(gt_bboxes) 。
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
# 将 GIoU 损失乘以 对应的权重 self.loss_gain["giou"] 。
loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
# 返回损失字典,移除多余的维度。
return {k: v.squeeze() for k, v in loss.items()}
# 这段代码实现了边界框损失和 GIoU 损失的计算,支持以下功能。边界框损失:使用 L1 损失计算预测边界框和目标边界框之间的差异,并乘以对应的权重进行调整。GIoU 损失:使用 GIoU 计算预测边界框和目标边界框之间的重叠度,并乘以对应的权重进行调整。归一化处理:将边界框损失和 GIoU 损失除以目标边界框的数量,避免目标框数量为零时的数值问题。该方法是 DETRLoss 类的核心部分之一,用于计算边界框相关的损失值。
# This function is for future RT-DETR Segment models
# def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
# # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
# name_mask = f'loss_mask{postfix}'
# name_dice = f'loss_dice{postfix}'
#
# loss = {}
# if sum(len(a) for a in gt_mask) == 0:
# loss[name_mask] = torch.tensor(0., device=self.device)
# loss[name_dice] = torch.tensor(0., device=self.device)
# return loss
#
# num_gts = len(gt_mask)
# src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
# src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
# # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
# loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
# torch.tensor([num_gts], dtype=torch.float32))
# loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
# return loss
# This function is for future RT-DETR Segment models
# @staticmethod
# def _dice_loss(inputs, targets, num_gts):
# inputs = F.sigmoid(inputs).flatten(1)
# targets = targets.flatten(1)
# numerator = 2 * (inputs * targets).sum(1)
# denominator = inputs.sum(-1) + targets.sum(-1)
# loss = 1 - (numerator + 1) / (denominator + 1)
# return loss.sum() / num_gts
# 这段代码定义了 DETRLoss 类中的 _get_loss_aux 方法,用于计算辅助损失(Auxiliary Losses)。辅助损失通常用于 Transformer 的解码器中间层,以增强模型的训练效果。
# 定义了 _get_loss_aux 方法,用于计算辅助损失。它接收以下参数 :
# 1.pred_bboxes :预测的边界框,形状为 [num_layers, batch_size, num_queries, 4] 。
# 2.pred_scores :预测的类别分数,形状为 [num_layers, batch_size, num_queries, num_classes] 。
# 3.gt_bboxes :目标边界框。
# 4.gt_cls :目标类别。
# 5.gt_groups :每张图片的目标框数量。
# 6.match_indices :匹配索引(可选)。
# 7.postfix :损失名称的后缀(可选)。
# 8.masks 和 9.gt_mask :掩码相关参数(可选,用于分割任务)。
def _get_loss_aux(
self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
match_indices=None,
postfix="",
masks=None,
gt_mask=None,
):
# 获取辅助损失。
"""Get auxiliary losses."""
# NOTE: loss class, bbox, giou, mask, dice
# 初始化一个零张量 loss ,用于 存储不同类型的辅助损失 。
# 如果提供了掩码( masks 和 gt_mask ),则初始化长度为 5 的张量,用于存储 分类损失 、 边界框损失 、 GIoU 损失 、 掩码损失 和 Dice 损失 。
# 否则,初始化长度为 3 的张量,仅用于存储 分类损失 、 边界框损失 和 GIoU 损失 。
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
# 这段代码是 _get_loss_aux 方法中的一部分,用于在辅助损失计算时处理匹配索引( match_indices )。它的作用是根据配置动态决定是否使用固定层的预测结果来计算匹配索引。
# 如果 match_indices 为 None ,表示 没有预先提供匹配索引 。 同时, self.use_uni_match 为 True ,表示 启用了固定层匹配功能 。
if match_indices is None and self.use_uni_match:
# 匹配索引的计算。
# self.matcher :调用 HungarianMatcher 对象(匈牙利匹配器),用于计算预测框和目标框之间的最优匹配。
match_indices = self.matcher(
# pred_bboxes[self.uni_match_ind] 和 pred_scores[self.uni_match_ind] :使用 self.uni_match_ind 指定的固定层的 预测边界框 和 预测分数 作为输入。 这意味着在辅助损失计算中,只使用解码器的某一层(通常是中间层)的预测结果来计算匹配索引。
pred_bboxes[self.uni_match_ind],
pred_scores[self.uni_match_ind],
# gt_bboxes , gt_cls , gt_groups : 目标边界框 、 目标类别 和 目标组 (每张图片的目标框数量)作为匹配器的输入。
gt_bboxes,
gt_cls,
gt_groups,
# masks 和 gt_mask :如果提供了 掩码 ( masks )和 目标掩码 ( gt_mask ),则将固定层的掩码传递给匹配器。 如果没有掩码, masks 设置为 None 。
masks=masks[self.uni_match_ind] if masks is not None else None,
gt_mask=gt_mask,
)
# 固定层匹配的启用。如果 self.use_uni_match 为 True ,则在辅助损失计算中仅使用解码器的某一层(由 self.uni_match_ind 指定)的预测结果来计算匹配索引。这种设计允许在辅助分支中使用特定层的预测结果,而不是使用最后一层的预测结果,从而增强模型对中间层的监督。动态匹配索引计算:如果没有预先提供匹配索引( match_indices 为 None ),则动态计算匹配索引。匹配索引的计算基于匈牙利算法,通过最小化预测框和目标框之间的匹配成本来实现。掩码支持(可选):如果任务涉及掩码(例如实例分割),则可以将掩码信息传递给匹配器,以增强匹配的准确性。
# 增强中间层的监督:在 DETR 模型中,辅助损失通常用于解码器的中间层,以增强模型的训练效果。使用固定层的预测结果计算匹配索引,可以确保中间层的预测结果得到充分的监督,从而提高模型的整体性能。灵活性:通过 self.use_uni_match 和 self.uni_match_ind ,用户可以灵活选择是否启用固定层匹配以及指定哪一层的预测结果用于匹配。掩码支持:如果任务涉及掩码(例如实例分割),则可以将掩码信息纳入匹配过程,从而提高匹配的准确性和任务的性能。这种设计方式使得 _get_loss_aux 方法能够灵活地计算辅助损失,同时支持多种任务需求(如目标检测和实例分割)。
# 这段代码是 _get_loss_aux 方法的核心部分,用于计算每一层的辅助损失,并将它们累加起来。
# 循环遍历每一层的预测结果。
# pred_bboxes 和 pred_scores 是包含多层预测结果的列表或张量,分别表示 预测的边界框 和 类别分数 。 使用 zip 函数将它们配对,并通过 enumerate 获取 每一层的索引 i 和 对应的预测结果 (aux_bboxes, aux_scores) 。
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
# 提取 当前层的掩码 (如果存在)。 如果提供了掩码列表 masks ,则提取第 i 层的掩码 masks[i] 。 如果没有提供掩码( masks 为 None ),则将 aux_masks 设置为 None 。 这一步是为了支持掩码相关任务(如实例分割),但当前代码中掩码相关的损失计算被注释掉了。
aux_masks = masks[i] if masks is not None else None
# 调用 _get_loss 方法计算当前层的损失。 _get_loss 方法计算 分类损失 、 边界框损失 和 GIoU 损失 。 返回值 loss_ 是一个字典,包含当前层的 分类损失 、 边界框损失 和 GIoU 损失 。
loss_ = self._get_loss(
# 当前层的 预测边界框 aux_bboxes 和 预测分数 aux_scores 。
aux_bboxes,
aux_scores,
# 目标边界框 gt_bboxes 、 目标类别 gt_cls 和 目标组 gt_groups 。
gt_bboxes,
gt_cls,
gt_groups,
# 当前层的掩码 aux_masks 和 目标掩码 gt_mask (如果存在)。
masks=aux_masks,
gt_mask=gt_mask,
# 损失名称的后缀 postfix 和 匹配索引 match_indices 。
postfix=postfix,
match_indices=match_indices,
)
# 将当前层的损失累加到总损失中。 每一层的损失通过 _get_loss 方法计算,并通过后缀 postfix 区分不同的损失来源。
# 累加分类损失。
loss[0] += loss_[f"loss_class{postfix}"]
# 累加边界框损失。
loss[1] += loss_[f"loss_bbox{postfix}"]
# 累加 GIoU 损失。
loss[2] += loss_[f"loss_giou{postfix}"]
# if masks is not None and gt_mask is not None:
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
# loss[3] += loss_[f'loss_mask{postfix}']
# loss[4] += loss_[f'loss_dice{postfix}']
# 将累加的损失封装到一个字典中。 字典的键是 辅助损失的名称 ,包含后缀 postfix 。 字典的值是 累加后的损失值 。 这样设计是为了在主损失和辅助损失之间保持一致的命名格式。
loss = {
f"loss_class_aux{postfix}": loss[0],
f"loss_bbox_aux{postfix}": loss[1],
f"loss_giou_aux{postfix}": loss[2],
}
# if masks is not None and gt_mask is not None:
# loss[f'loss_mask_aux{postfix}'] = loss[3]
# loss[f'loss_dice_aux{postfix}'] = loss[4]
# 返回包含辅助损失的字典。
return loss
# 这段代码实现了以下功能。逐层计算辅助损失:遍历每一层的预测结果,调用 _get_loss 方法计算分类损失、边界框损失和 GIoU 损失。支持掩码相关任务(虽然当前代码中掩码部分被注释掉了)。损失累加:将每一层的损失累加到总损失中,最终返回一个包含所有辅助损失的字典。灵活性:通过 postfix 参数支持主损失和辅助损失的区分。通过 masks 和 gt_mask 支持掩码相关任务(可选)。解耦设计:使用 _get_loss 方法封装单层损失的计算逻辑,使得代码更加模块化和易于扩展。这种设计方式使得 _get_loss_aux 方法能够高效地计算辅助损失,并为 Transformer 解码器的中间层提供额外的监督信号,从而增强模型的训练效果。
# 这段代码实现了辅助损失的计算,用于 Transformer 解码器的中间层。它支持以下功能。多层损失计算:遍历每一层的预测结果,计算分类损失、边界框损失和 GIoU 损失。固定层匹配:如果启用了固定层匹配( self.use_uni_match ),则使用指定层的预测结果计算匹配索引。掩码损失(可选):如果提供了掩码和目标掩码,还可以计算掩码损失和 Dice 损失(这部分代码被注释掉了)。损失累加:将每一层的损失累加,最终返回一个包含所有辅助损失的字典。辅助损失的目的是通过在中间层引入额外的监督信号,增强模型的训练效果,从而提高模型的整体性能。
# 这段代码定义了 DETRLoss 类中的一个静态方法 _get_index ,用于从匹配索引 match_indices 中提取批次索引、源索引和目标索引。这些索引用于将预测结果与目标数据对齐。
@staticmethod
# 定义了一个静态方法 _get_index ,接收 match_indices 作为输入。
# 1.match_indices :是一个列表,其中每个元素是一个元组 (src_idx, dst_idx) ,分别表示预测框和目标框的索引。
def _get_index(match_indices):
# 从提供的匹配索引返回批量索引、源索引和目标索引。
"""Returns batch indices, source indices, and destination indices from provided match indices."""
# 提取 批次索引 。
# 遍历 match_indices ,对于每个元组 (src, _) ,使用 torch.full_like(src, i) 创建一个与 src 形状相同的张量,其值为当前批次索引 i 。
# 使用 torch.cat 将所有批次索引拼接成一个完整的张量 batch_idx 。
# 这一步的目的是为每个匹配的预测框分配对应的批次索引。
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
# 提取 源索引 。
# 遍历 match_indices ,提取每个元组中的 src (预测框的索引)。
# 使用 torch.cat 将所有源索引拼接成一个完整的张量 src_idx 。
# 这一步的目的是获取所有匹配的预测框索引。
src_idx = torch.cat([src for (src, _) in match_indices])
# 提取 目标索引 。
# 遍历 match_indices ,提取每个元组中的 dst (目标框的索引)。
# 使用 torch.cat 将所有目标索引拼接成一个完整的张量 dst_idx 。
# 这一步的目的是获取所有匹配的目标框索引。
dst_idx = torch.cat([dst for (_, dst) in match_indices])
# 返回一个元组 (batch_idx, src_idx) 和 dst_idx ,分别表示 批次索引 、 源索引 和 目标索引 。 这些索引用于将预测框与目标框对齐,以便计算损失。
return (batch_idx, src_idx), dst_idx
# 这段代码的作用是从匹配索引 match_indices 中提取批次索引、源索引和目标索引。它的主要功能包括。批次索引提取:为每个匹配的预测框分配对应的批次索引,确保索引与批次信息对齐。源索引提取:提取所有匹配的预测框索引,用于定位预测结果。目标索引提取:提取所有匹配的目标框索引,用于定位目标数据。对齐索引:返回的索引用于将预测框与目标框对齐,以便在后续步骤中计算损失。这种设计方式使得索引提取过程简洁高效,适用于 DETR 模型中预测框与目标框的匹配和对齐操作。
# 这段代码定义了 DETRLoss 类中的 _get_assigned_bboxes 方法,用于根据匹配索引 match_indices 提取预测边界框和目标边界框的匹配部分。
# 定义了 _get_assigned_bboxes 方法,接收以下参数 :
# 1.pred_bboxes :预测的边界框,形状为 [batch_size, num_queries, 4] 。
# 2.gt_bboxes :目标边界框,形状为 [batch_size, num_gts, 4] 。
# 3.match_indices :匹配索引,表示预测框和目标框之间的对应关系。
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
# 根据匹配索引将预测边界框分配给地面真实边界框。
"""Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
# 提取 匹配的预测边界框 。
# 遍历 pred_bboxes 和 match_indices , t 表示 当前批次的预测边界框 , (i, _) 表示 匹配索引 。
# 对于每个批次 :如果匹配索引 i 的长度大于 0,则从 预测边界框 t 中提取索引为 i 的框。 如果匹配索引为空( len(i) == 0 ),则生成一个形状为 [0, t.shape[-1]] 的零张量,表示没有匹配的预测框。
# 使用 torch.cat 将 所有批次的匹配预测框 拼接成一个完整的张量 pred_assigned 。
pred_assigned = torch.cat(
[
t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (i, _) in zip(pred_bboxes, match_indices)
]
)
# 提取 匹配的目标边界框 。
# 遍历 gt_bboxes 和 match_indices , t 表示 当前批次的目标边界框 , (_, j) 表示 匹配索引 。
# 对于每个批次 :如果匹配索引 j 的长度大于 0,则从 目标边界框 t 中提取索引为 j 的框。 如果匹配索引为空( len(j) == 0 ),则生成一个形状为 [0, t.shape[-1]] 的零张量,表示没有匹配的目标框。
# 使用 torch.cat 将 所有批次的匹配目标框 拼接成一个完整的张量 gt_assigned 。
gt_assigned = torch.cat(
[
t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (_, j) in zip(gt_bboxes, match_indices)
]
)
# 返回两个张量 pred_assigned 和 gt_assigned ,分别表示 匹配的预测边界框 和 目标边界框 。 这些张量用于后续的损失计算,确保只对匹配的框进行计算。
return pred_assigned, gt_assigned
# 这段代码的作用是根据匹配索引 match_indices 提取预测边界框和目标边界框的匹配部分。它的主要功能包括。预测框提取:遍历每个批次的预测框,根据匹配索引提取对应的框。如果没有匹配的框,则生成零张量以保持张量形状一致。目标框提取:遍历每个批次的目标框,根据匹配索引提取对应的框。如果没有匹配的框,则生成零张量以保持张量形状一致。拼接结果:使用 torch.cat 将所有批次的匹配框拼接成一个完整的张量,便于后续的损失计算。这种设计方式确保了预测框和目标框的对齐,使得损失计算只针对匹配的部分进行,从而提高了计算效率和模型的训练效果。
# 这段代码定义了 DETRLoss 类中的 _get_loss 方法,用于计算主损失,包括分类损失、边界框损失和可选的掩码损失。该方法的核心功能是根据匹配索引对预测结果和目标数据进行对齐,并计算相应的损失。
# 定义了 _get_loss 方法,用于计算主损失。它接收以下参数 :
# 1.pred_bboxes :预测的边界框。
# 2.pred_scores :预测的类别分数。
# 3.gt_bboxes :目标边界框。
# 4.gt_cls :目标类别。
# 5.gt_groups :每张图片的目标框数量。
# 6.masks 和 7.gt_mask :掩码相关参数(可选)。
# 8.postfix :损失名称的后缀(可选)。
# 9.match_indices :匹配索引(可选)。
def _get_loss(
self,
pred_bboxes,
pred_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=None,
gt_mask=None,
postfix="",
match_indices=None,
):
# 计算损失。
"""Get losses."""
# 计算 匹配索引 。
# 如果未提供匹配索引( match_indices 为 None ),则调用 self.matcher (匈牙利匹配器)计算 预测框和目标框之间的最优匹配 。 匹配器的输入包括 预测框 、 预测分数 、 目标框 、 目标类别 、 目标组 ,以及 可选的掩码参数 。
if match_indices is None:
match_indices = self.matcher(
pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
)
# 提取 索引 。
# 调用 _get_index 方法从匹配索引中提取 批次索引 idx 和 目标索引 gt_idx 。 idx 包含 批次索引 和 源索引 ,用于 定位预测框 ; gt_idx 用于 定位目标框 。
idx, gt_idx = self._get_index(match_indices)
# 对齐预测框和目标框。
# 使用 提取的索引 对预测框和目标框进行对齐,确保它们一一对应。 这一步是后续损失计算的基础,确保只对匹配的框进行计算。
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
# 提取 批次大小 和 查询数量 。 bs 表示批次大小。 nq 表示每个批次的查询数量(即 预测框的数量 )。
bs, nq = pred_scores.shape[:2]
# 生成 目标类别标签 。创建一个形状为 [bs, nq] 的张量 targets ,初始值为背景类别( self.nc )。 使用匹配索引 idx 和 gt_idx ,将 目标类别 gt_cls 的值 填充到对应的位置 。 这一步确保目标类别标签与预测分数对齐。
targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
targets[idx] = gt_cls[gt_idx]
# 初始化 目标分数张量 。创建一个形状为 [bs, nq] 的零张量 gt_scores ,用于 存储目标分数 。
gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
# 计算 目标分数 。
# 如果存在目标框( len(gt_bboxes) > 0 ),则计算预测框和目标框之间的 IoU(交并比)。 使用 bbox_iou 函数计算 IoU,并将结果填充到 gt_scores 的对应位置。 这一步的目标分数用于后续的分类损失计算。
if len(gt_bboxes):
# def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
# -> 用于计算两个边界框(bounding boxes)之间的交并比(IoU)以及其变体(GIoU、DIoU、CIoU)。
# -> return iou - (rho2 / c2 + v * alpha) # CIoU
# -> return iou - rho2 / c2 # DIoU
# -> eturn iou - (c_area - union) / c_area # GIoU
# -> return iou # IoU
gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
# 返回 损失字典 。
# 调用 _get_loss_class 方法计算 分类损失 。
# 调用 _get_loss_bbox 方法计算 边界框损失 。
# 如果提供了掩码参数( masks 和 gt_mask ),还可以调用 _get_loss_mask 方法计算 掩码损失 (这部分代码被注释掉了)。
# 使用 字典解包 ( ** )将 不同类型的损失合并到一个字典中 并返回。
return {
**self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
**self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
# **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
}
# 这段代码的核心功能是计算 DETR 模型的主损失,包括分类损失和边界框损失。它通过以下步骤实现。匹配索引计算:如果未提供匹配索引,则调用匹配器计算预测框和目标框之间的最优匹配。索引提取:使用 _get_index 方法提取批次索引和目标索引。对齐预测框和目标框:确保预测框和目标框一一对应,便于后续损失计算。目标类别标签生成:为目标类别标签和目标分数生成对齐的张量。损失计算:调用 _get_loss_class 和 _get_loss_bbox 方法分别计算分类损失和边界框损失。支持可选的掩码损失计算(当前代码中被注释掉了)。这种设计方式使得损失计算过程清晰高效,同时支持多种任务需求(如目标检测和实例分割)。
# 这段代码定义了 DETRLoss 类的 forward 方法,用于计算 DETR 模型的总损失,包括主损失和可选的辅助损失。
# 定义了 forward 方法,这是 PyTorch 模块的标准前向传播方法。它接收以下参数 :
# 1.pred_bboxes :预测的边界框,形状为 [num_layers, batch_size, num_queries, 4] 。
# 2.pred_scores :预测的类别分数,形状为 [num_layers, batch_size, num_queries, num_classes] 。
# 3.batch :包含目标数据的字典,包括目标类别、目标边界框和每张图片的目标框数量。
# 4.postfix :损失名称的后缀,用于区分不同的损失来源。
# 5.**kwargs :额外的关键字参数,可能包括匹配索引 match_indices 。
def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
# 计算预测边界框和分数的损失。
# 参数:
# pred_bboxes (torch.Tensor):预测边界框,形状 [l, b, query, 4]。
# pred_scores (torch.Tensor):预测类别分数,形状 [l, b, query, num_classes]。
# batch (dict):批次信息包含:
# cls (torch.Tensor):地面实况类别,形状 [num_gts]。
# bboxes (torch.Tensor):地面实况边界框,形状 [num_gts, 4]。
# gt_groups (List[int]):批次中每幅图像的地面实况数量。
# postfix (str):损失名称的后缀。
# **kwargs (Any):其他参数,可能包括“match_indices”。
# 返回:
# (dict):计算的损失,包括主要和辅助(如果启用)。
# 注意:
# 使用 pred_bboxes 的最后一个元素和 pred_scores 表示主要损失,其余表示辅助损失(如果 self.aux_loss 为 True)。
"""
Calculate loss for predicted bounding boxes and scores.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
batch (dict): Batch information containing:
cls (torch.Tensor): Ground truth classes, shape [num_gts].
bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
gt_groups (List[int]): Number of ground truths for each image in the batch.
postfix (str): Postfix for loss names.
**kwargs (Any): Additional arguments, may include 'match_indices'.
Returns:
(dict): Computed losses, including main and auxiliary (if enabled).
Note:
Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
self.aux_loss is True.
"""
# 将 pred_bboxes 所在的设备(CPU 或 GPU)赋值给 self.device 。这一步确保后续的所有张量操作都在正确的设备上执行。
self.device = pred_bboxes.device
# 从关键字参数 kwargs 中提取 匹配索引 match_indices 。如果没有提供,则默认为 None 。
match_indices = kwargs.get("match_indices", None)
# 从 batch 字典中提取 目标数据 。
# gt_cls :目标类别。
# gt_bboxes :目标边界框。
# gt_groups :每张图片的目标框数量。
gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
# 调用 _get_loss 方法计算 主损失 。
# 使用 最后一层的预测结果 ( pred_bboxes[-1] 和 pred_scores[-1] )计算 分类损失 和 边界框损失 。
# 如果提供了匹配索引 match_indices ,则直接使用;否则, _get_loss 方法会自动计算匹配索引。
# 返回的 total_loss 是一个字典,包含 主损失的各个组成部分 。
total_loss = self._get_loss(
pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
)
# 如果启用了 辅助损失 ( self.aux_loss 为 True )。
if self.aux_loss:
# 调用 _get_loss_aux 方法计算辅助损失。 使用 除最后一层外的所有中间层的预测结果 ( pred_bboxes[:-1] 和 pred_scores[:-1] )。 将辅助损失的结果更新到 total_loss 字典中。
total_loss.update(
self._get_loss_aux(
pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
)
)
# 返回 包含主损失和辅助损失的字典 total_loss 。
return total_loss
# 这段代码实现了 DETR 模型的总损失计算,包括主损失和可选的辅助损失。它的主要功能包括。主损失计算:使用最后一层的预测结果计算分类损失和边界框损失。支持动态计算匹配索引或使用预计算的匹配索引。辅助损失计算(可选):如果启用了辅助损失,使用中间层的预测结果计算额外的损失。辅助损失通过 _get_loss_aux 方法计算,并更新到总损失中。灵活性:支持通过 postfix 参数区分不同的损失来源。支持通过 kwargs 提供额外的参数(如匹配索引)。设备一致性:通过 self.device 确保所有张量操作都在正确的设备上执行。这种设计方式使得 forward 方法能够高效地计算主损失和辅助损失,同时支持多种任务需求和配置选项。
# DETRLoss 类是一个为 DETR(DEtection TRansformer)模型设计的损失计算模块,旨在通过综合计算分类损失、边界框回归损失和 GIoU 损失,优化目标检测任务的性能。它支持主损失和辅助损失的计算,其中主损失基于最后一层的预测结果,而辅助损失则利用中间层的预测结果,为模型提供额外的监督信号,增强训练效果。此外,该类还支持焦点损失(Focal Loss)和变焦损失(Varifocal Loss)等可选损失函数,以及掩码相关损失(如分割任务中的掩码损失和 Dice 损失)的扩展功能。通过灵活的配置和模块化设计, DETRLoss 类能够适应多种任务需求,为 DETR 模型的训练提供了强大的支持。
3.class RTDETRDetectionLoss(DETRLoss):
# 这段代码定义了 RTDETRDetectionLoss 类,它是 DETRLoss 的一个子类,专门用于 RT-DETR 模型的目标检测任务。它在继承 DETRLoss 的基础上,增加了对去噪(denoising)训练的支持,通过计算去噪损失来增强模型的鲁棒性。
# 定义了一个名为 RTDETRDetectionLoss 的类,继承自 DETRLoss 。这意味着它继承了父类的所有属性和方法,并可以在此基础上进行扩展。
class RTDETRDetectionLoss(DETRLoss):
# 实时 DeepTracker (RT-DETR) 检测损失类,扩展了 DETRLoss。
# 此类计算 RT-DETR 模型的检测损失,其中包括标准检测损失以及提供去噪元数据时的额外去噪训练损失。
"""
Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
an additional denoising training loss when provided with denoising metadata.
"""
# 这段代码定义了 RTDETRDetectionLoss 类中的 forward 方法,用于计算 RT-DETR 模型的总损失,包括主损失和可选的去噪(denoising)损失。该方法继承了 DETRLoss 的主损失计算逻辑,并扩展了对去噪训练的支持。
# 定义了 forward 方法,用于计算模型的总损失。它接收以下参数 :
# 1.preds :模型的预测结果,包含预测边界框和预测分数。
# 2.batch :目标数据,包含目标类别、目标边界框和每张图片的目标框数量。
# 3.dn_bboxes 和 dn_scores :去噪预测的边界框和分数(可选)。
# 4.dn_meta :去噪元数据(可选),包含去噪训练所需的信息。
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
# 前向传递以计算检测损失。
"""
Forward pass to compute the detection loss.
Args:
preds (tuple): Predicted bounding boxes and scores.
batch (dict): Batch data containing ground truth information.
dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
dn_meta (dict, optional): Metadata for denoising. Default is None.
Returns:
(dict): Dictionary containing the total loss and, if applicable, the denoising loss.
"""
# 将输入的 预测结果 preds 解包为 预测边界框 pred_bboxes 和 预测分数 pred_scores 。
pred_bboxes, pred_scores = preds
# 调用父类 DETRLoss 的 forward 方法计算 主损失 。使用 最后一层的预测边界框 和 预测分数 。 返回的 total_loss 是一个字典,包含 主损失的各个组成部分 (如分类损失、边界框损失等)。
total_loss = super().forward(pred_bboxes, pred_scores, batch)
# Check for denoising metadata to compute denoising training loss 检查去噪元数据以计算去噪训练损失。
# 检查 是否提供了去噪元数据 dn_meta 。如果提供了,说明需要计算去噪损失。
if dn_meta is not None:
# 从 dn_meta 中提取 去噪正样本索引 dn_pos_idx 和 去噪组数 dn_num_group 。
dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
# 断言 目标组的数量 len(batch["gt_groups"]) 与 去噪正样本索引 的数量一致,确保数据一致性。
assert len(batch["gt_groups"]) == len(dn_pos_idx)
# Get the match indices for denoising 获取去噪的匹配索引。
# 调用 get_dn_match_indices 方法计算 去噪匹配索引 。输入包括 去噪正样本索引 、 去噪组数 和 目标组数量 。 返回的 match_indices 用于 将去噪预测与目标数据对齐 。
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
# Compute the denoising training loss 计算去噪训练损失。
# 调用父类的 forward 方法计算 去噪损失 。使用 去噪预测边界框 dn_bboxes 和 去噪预测分数 dn_scores 。 添加后缀 _dn 以区分主损失和去噪损失。 使用去噪匹配索引 match_indices 确保对齐。
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
# 将计算得到的 去噪损失 更新到总损失字典 total_loss 中。
total_loss.update(dn_loss)
# 如果没有提供 去噪元数据 ( dn_meta 为 None ),则跳过去噪损失的计算。
else:
# If no denoising metadata is provided, set denoising loss to zero 如果没有提供去噪元数据,则将去噪损失设置为零。
# 为 每个主损失项 添加一个对应的 去噪损失项 ,值为零。这一步确保了即使没有去噪训练,损失字典的结构也保持一致。
total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
# 返回包含 主损失 和 去噪损失 的总损失字典。
return total_loss
# 这段代码实现了 RT-DETR 模型的总损失计算,支持主损失和去噪损失的计算。其核心功能包括。主损失计算:继承自 DETRLoss ,基于最后一层的预测结果计算主损失。去噪损失计算(可选):如果提供了去噪元数据,计算去噪损失。使用 get_dn_match_indices 方法生成去噪匹配索引,确保去噪预测与目标数据对齐。灵活性:如果没有提供去噪元数据,去噪损失将被设置为零,确保代码的兼容性。结构一致性:即使没有去噪训练,损失字典的结构也保持一致,便于后续的优化和调试。这种设计方式使得 RTDETRDetectionLoss 类能够同时支持标准训练和去噪训练,增强了 RT-DETR 模型的鲁棒性和训练效果。
# 这段代码定义了 RTDETRDetectionLoss 类中的静态方法 get_dn_match_indices ,用于生成去噪(denoising)训练所需的匹配索引。这些索引用于将去噪预测与目标数据对齐,从而计算去噪损失。
@staticmethod
# 定义了一个静态方法 get_dn_match_indices ,用于生成去噪匹配索引。它接收以下参数 :
# 1.dn_pos_idx :去噪正样本索引,表示去噪预测中与目标框匹配的部分。
# 2.dn_num_group :去噪组的数量,表示每个目标框对应的去噪预测数量。
# 3.gt_groups :每张图片的目标框数量,用于确定目标框的分组。
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
# 获取用于去噪的匹配索引。
"""
Get the match indices for denoising.
Args:
dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
dn_num_group (int): Number of denoising groups.
gt_groups (List[int]): List of integers representing the number of ground truths for each image.
Returns:
(List[tuple]): List of tuples containing matched indices for denoising.
"""
# 初始化一个空列表 dn_match_indices ,用于 存储每张图片的去噪匹配索引 。
dn_match_indices = []
# 计算目标框的 累积索引 idx_groups 。
# [0, *gt_groups[:-1]] :创建一个列表,包含每张图片的目标框数量,但最后一张图片的数量除外,并在前面添加一个 0。
# .cumsum_(0) :计算累积和,用于 确定每张图片目标框的起始索引 。
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
# 遍历每张图片的 目标框数量 gt_groups 。
for i, num_gt in enumerate(gt_groups):
# 如果当前图片有目标框( num_gt > 0 ),则进行以下操作:
if num_gt > 0:
# 生成当前图片的 目标框索引 。
# torch.arange(end=num_gt, dtype=torch.long) :生成从 0 到 num_gt - 1 的索引。
# + idx_groups[i] :将 目标框索引 偏移到 当前图片的起始位置 。
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
# 将目标框索引重复 dn_num_group 次,以 匹配去噪预测的数量 。
gt_idx = gt_idx.repeat(dn_num_group)
# 断言 去噪正样本索引 的长度与 目标框索引 的长度一致,确保数据一致性。
assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, "
f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively." # 预期长度相同,但分别得到 {len(dn_pos_idx[i])} 和 {len(gt_idx)}。
# 将 当前图片的 去噪正样本索引 和 目标框索引 作为元组添加到 dn_match_indices 列表中。
dn_match_indices.append((dn_pos_idx[i], gt_idx))
# 如果当前图片没有目标框( num_gt == 0 ),则添加一个 空的匹配索引 ,避免后续计算出错。
else:
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
# 返回 生成的去噪匹配索引列表 。
return dn_match_indices
# 这段代码的核心功能是生成去噪训练所需的匹配索引,这些索引用于将去噪预测与目标数据对齐。其主要步骤包括。目标框索引生成:根据每张图片的目标框数量,生成目标框的全局索引。索引重复:将目标框索引重复 dn_num_group 次,以匹配去噪预测的数量。数据一致性检查:确保去噪正样本索引与目标框索引的长度一致。空索引处理:如果某张图片没有目标框,添加空的匹配索引,避免后续计算出错。这种设计方式确保了去噪训练的正确性,使得模型能够更好地处理噪声数据,从而提高鲁棒性和性能。
# RTDETRDetectionLoss 类扩展了 DETRLoss ,增加了对去噪训练的支持。它通过以下功能实现。主损失计算:继承自 DETRLoss 的主损失计算逻辑,基于最后一层的预测结果。去噪损失计算:如果提供了去噪元数据,计算去噪损失。使用 get_dn_match_indices 方法生成去噪匹配索引,确保去噪预测与目标数据对齐。灵活性:如果没有提供去噪元数据,去噪损失将被设置为零,确保代码的兼容性。去噪匹配索引生成:静态方法 get_dn_match_indices 用于生成去噪匹配索引,确保去噪训练的正确性。这种设计方式使得 RTDETRDetectionLoss 类能够同时支持标准训练和去噪训练,增强了 RT-DETR 模型的鲁棒性和训练效果。