PETR: Position Embedding Transformation for Multi-View 3D Object Detection

Megvii, ECCV 2022

A camera-only BEV approach: transformer-based 3D detection.

paper: [2203.05625] PETR: Position Embedding Transformation for Multi-View 3D Object Detection

code: GitHub - megvii-research/PETR: [ECCV2022] PETR: Position Embedding Transformation for Multi-View 3D Object Detection & [ICCV2023] PETRv2: A Unified Framework for 3D Perception from Multi-Camera Images

Goal: take the 2D features from the surround-view cameras, add 3D position encodings, and convert them into a 3D representation.

  1. Discretize the camera frustum space into a grid of coordinate points
  2. Transform the grid points into 3D coordinates in the ego frame using the camera parameters
  3. Extract features from the camera images; these 2D features plus the 3D coordinates feed the position encoder, which outputs features carrying 3D position information
  4. Feed those features into a transformer decoder, where they interact with the object queries to produce the detection results
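
A minimal sketch of steps 1-2 (LID frustum discretization and the camera-to-ego transform). The function name frustum_to_ego and all shapes/defaults here are illustrative assumptions; in the PETR repo this logic lives inside the head's position-embedding step:

import torch

def frustum_to_ego(img2ego, H=16, W=44, D=64, d_min=1.0, d_max=61.2,
                   pc_range=(-61.2, -61.2, -10.0, 61.2, 61.2, 10.0)):
    # img2ego: [N, 4, 4], inverse of the per-camera lidar2img projection matrices
    i = torch.arange(D, dtype=torch.float32)
    bin_size = (d_max - d_min) / (D * (D + 1))       # LID: bin width grows linearly with index
    ds = d_min + bin_size * i * (i + 1)              # depth bin centers
    us = torch.arange(W, dtype=torch.float32) + 0.5  # pixel centers along u
    vs = torch.arange(H, dtype=torch.float32) + 0.5  # pixel centers along v
    u, v, d = torch.meshgrid(us, vs, ds, indexing="ij")               # each [W, H, D]
    # homogeneous frustum points (u*d, v*d, d, 1) so one 4x4 matmul lifts them to 3D
    pts = torch.stack((u * d, v * d, d, torch.ones_like(d)), dim=-1)  # [W, H, D, 4]
    pts = torch.einsum("nij,whdj->nwhdi", img2ego, pts)[..., :3]      # [N, W, H, D, 3]
    # min-max normalize into [0, 1] over the perception range
    lo, hi = pts.new_tensor(pc_range[:3]), pts.new_tensor(pc_range[3:])
    return (pts - lo) / (hi - lo)

coords3d = frustum_to_ego(torch.eye(4).repeat(6, 1, 1))  # identity matrices as stand-ins -> [6, 44, 16, 64, 3]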

Differences from DETR3D: in DETR3D the 3D reference points are regressed from learnable query embeddings and then projected back into the images to sample 2D features; in PETR the reference points themselves are learnable, the queries are generated from their position encoding, and attention runs over global position-aware features.

DETR3D:

Detr3DHead
    __init__
        # each query embedding packs a positional and a content part: 2 * embed_dims
        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2)

    forward
        query_embeds = self.query_embedding.weight
        hs, init_reference, inter_references = self.transformer(
            mlvl_feats,
            query_embeds,
            reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
            img_metas=img_metas,
        )

Detr3DTransformer
    __init__
        self.embed_dims = self.decoder.embed_dims
        self.reference_points = nn.Linear(self.embed_dims, 3)  # 3D reference points regressed from the queries

    forward(self, mlvl_feats, query_embed, reg_branches=None, **kwargs):
        # split the learnable embedding into its positional and content halves
        query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)  # bs taken from mlvl_feats

        # reference points are predicted from the query positions
        reference_points = self.reference_points(query_pos).sigmoid()

PETR:

PETRHead
    __init__
        num_query=900
        self.embed_dims = 256

        # reference points are directly learnable 3D anchors, uniformly initialized in [0, 1]
        self.reference_points = nn.Embedding(self.num_query, 3)
        nn.init.uniform_(self.reference_points.weight.data, 0, 1)

        # MLP mapping the sinusoidal encoding (embed_dims*3//2 = 384 dims) to a query
        self.query_embedding = nn.Sequential(
            nn.Linear(self.embed_dims*3//2, self.embed_dims),
            nn.ReLU(),
            nn.Linear(self.embed_dims, self.embed_dims),
        )

    forward
        reference_points = self.reference_points.weight
        # queries are generated from the position encoding of the reference points
        query_embeds = self.query_embedding(pos2posemb3d(reference_points))

        reference_points = reference_points.unsqueeze(0).repeat(batch_size, 1, 1)  # already in [0, 1], so no .sigmoid() needed
        
        outs_dec, _ = self.transformer(x, masks, query_embeds, pos_embed, self.reg_branches)
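
pos2posemb3d maps each 3D reference point to a sinusoidal embedding of 3 * 128 = 384 dimensions, which is exactly the embed_dims*3//2 input size of query_embedding above. A sketch consistent with the PETR repo (defaults assumed):

import math
import torch

def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
    # pos: [..., 3] points in [0, 1]; returns [..., 3 * num_pos_feats]
    pos = pos * 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
    pos_x = pos[..., 0, None] / dim_t
    pos_y = pos[..., 1, None] / dim_t
    pos_z = pos[..., 2, None] / dim_t
    # interleave sin/cos per axis, then concatenate the three axes
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_z = torch.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), dim=-1).flatten(-2)
    return torch.cat((pos_y, pos_x, pos_z), dim=-1)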

PETRDNTransformer
    __init__
        
    forward(self, x, mask, query_embed, pos_embed, attn_masks=None, reg_branch=None)
        query_embed = query_embed.transpose(0, 1)  # [bs, num_query, dim] -> [num_query, bs, dim]

        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,  # 3D position embedding of the flattened image features
            query_pos=query_embed,
            key_padding_mask=mask,
            attn_masks=[attn_masks, None],
            reg_branch=reg_branch,
            )
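
For context, target and memory are not defined in the snippet above; in PETR's transformer forward (paraphrased from the repo), memory is the flattened multi-camera feature map and the query content starts at zero:

bs, n, c, h, w = x.shape
memory = x.permute(1, 3, 4, 0, 2).reshape(-1, bs, c)  # [n*h*w, bs, c]: all cameras flattened together
target = torch.zeros_like(query_embed)                # decoder query content, zero-initialized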


Components

  1. 3D coordinate generation
    1. Input: the frustum grid points and each camera's extrinsics
      1. The depth axis is sampled with CaDDN's LID (linear-increasing discretization)
    2. Output: 3D coordinate points
    3. Process:
      1. Compute the 3D coordinates in the ego frame (cf. the frustum sketch after the overview list above)
      2. Min-max normalization into [0, 1]
  2. 3D position encoder (see the sketch after this list)
    1. Input: 2D features, 3D coordinate points
    2. Output: 3D position-aware features
    3. Process:
      1. Feed the 3D coordinates through an MLP to obtain the 3D position embedding
      2. Pass the 2D features through a 1*1 convolution and add the 3D position embedding to get the 3D features
        1. Feature extractors: ResNet, Swin Transformer, VoVNetV2
      3. Flatten the 3D features
  3. Object query
    1. Input: reference points, nn.Embedding(self.num_query, 3) (PETRHead.reference_points)
    2. Output: object queries Q0 (query_embeds)
      1. Process: uniform initialization in [0, 1], sinusoidal position encoding (pos2posemb3d, sketched above), then an MLP (PETRHead.query_embedding)
  4. Decoder
    1. Multi-head attention plus FFN layers; during training the object queries learn to extract high-dimensional object features
  5. Head and losses
    1. Classification head: predicts the object class, trained with focal loss
    2. Box head: predicts offsets relative to the reference points, trained with L1 loss
    3. Hungarian matching against the ground truth
    4. Same as DETR3D
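
A minimal sketch of component 2, the 3D position encoder. The stand-alone class and shapes are assumptions for illustration; in PETR the equivalent pieces are a 1*1 input_proj conv and a conv-MLP position encoder inside PETRHead:

import torch
import torch.nn as nn

class PositionEncoder3D(nn.Module):
    def __init__(self, in_channels=2048, embed_dims=256, depth_num=64):
        super().__init__()
        self.input_proj = nn.Conv2d(in_channels, embed_dims, kernel_size=1)  # 1*1 conv on the 2D features
        self.position_encoder = nn.Sequential(  # MLP (as 1*1 convs) over the 3D coordinates
            nn.Conv2d(depth_num * 3, embed_dims * 4, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(embed_dims * 4, embed_dims, kernel_size=1),
        )

    def forward(self, feats2d, coords3d):
        # feats2d: [B*N, C, H, W] backbone features; coords3d: [B*N, D*3, H, W] in [0, 1]
        pos_embed = self.position_encoder(coords3d)     # 3D position embedding
        feats3d = self.input_proj(feats2d) + pos_embed  # position-aware "3D" features
        return feats3d.flatten(2)                       # flatten H*W for the decoder

enc = PositionEncoder3D()
out = enc(torch.randn(6, 2048, 16, 44), torch.rand(6, 192, 16, 44))  # -> [6, 256, 704]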