Megvii, ECCV 2022
A vision-only BEV approach: a transformer network for multi-view 3D detection
paper:[2203.05625] PETR: Position Embedding Transformation for Multi-View 3D Object Detection
Goal: take 2D features from the surround-view cameras, add 3D position encoding, and turn them into a 3D-aware representation
- Discretize each camera's frustum space into grid points
- Transform the grid points into 3D coordinates in the ego frame using the camera parameters
- Extract 2D features from the cameras and feed them, together with the 3D coordinates, into a position encoder that outputs 3D-position-aware features
- Feed those features into a transformer decoder, where they interact with object queries to produce the detections
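The steps above can be sketched as a shape walk-through. All sizes are illustrative, and the single cross-attention layer is a stand-in for the full transformer decoder:

```python
import torch
import torch.nn as nn

# cameras, channels, feature-map height/width, number of object queries
N, C, H, W, Q = 6, 256, 16, 44, 900

# 2D backbone features already fused with the 3D position embedding
feat_3d = torch.randn(1, N, C, H, W)

# flatten all cameras and spatial positions into one token sequence
tokens = feat_3d.permute(0, 1, 3, 4, 2).reshape(1, N * H * W, C)

# object queries (in PETR these come from 3D reference points, see below)
queries = torch.randn(1, Q, C)

# each query attends to the tokens of all cameras at once
cross_attn = nn.MultiheadAttention(C, num_heads=8, batch_first=True)
updated, _ = cross_attn(query=queries, key=tokens, value=tokens)
# updated queries are then decoded into 3D boxes by the heads
```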
Differences from DETR3D
DETR3D:
Detr3DHead
__init__
self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2)
forward
query_embeds = self.query_embedding.weight
hs, init_reference, inter_references = self.transformer(
mlvl_feats,
query_embeds,
reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501
img_metas=img_metas,
)
Detr3DTransformer
__init__
self.embed_dims = self.decoder.embed_dims
self.reference_points = nn.Linear(self.embed_dims, 3)
forward(self, mlvl_feats, query_embed, reg_branches=None, **kwargs):
query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
reference_points = self.reference_points(query_pos).sigmoid()
PETR:
PETRHead
__init__
num_query=900
self.embed_dims = 256
self.reference_points = nn.Embedding(self.num_query, 3)
nn.init.uniform_(self.reference_points.weight.data, 0, 1)
self.query_embedding = nn.Sequential(
nn.Linear(self.embed_dims*3//2, self.embed_dims),
nn.ReLU(),
nn.Linear(self.embed_dims, self.embed_dims),
)
forward
reference_points = self.reference_points.weight
query_embeds = self.query_embedding(pos2posemb3d(reference_points))
reference_points = reference_points.unsqueeze(0).repeat(batch_size, 1, 1) #.sigmoid()
outs_dec, _ = self.transformer(x, masks, query_embeds, pos_embed, self.reg_branches)
PETRDNTransformer
__init__
forward(self, x, mask, query_embed, pos_embed, attn_masks=None, reg_branch=None)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]
out_dec = self.decoder(
query=target,
key=memory,
value=memory,
key_pos=pos_embed,
query_pos=query_embed,
key_padding_mask=mask,
attn_masks=[attn_masks, None],
reg_branch=reg_branch,
)
Components
- Generate 3D coordinates
  - Input: frustum grid points; each camera's extrinsics
  - The frustum is sampled with CaDDN's LID discretization
  - Output: 3D coordinate points
  - Process:
    - Compute the 3D points from the grid points via the camera parameters
    - Min-max normalize them
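The LID sampling and normalization above can be sketched as follows. The depth range (1 to 61.2 m) and 64 bins follow common PETR configs but are assumptions here, and the helper names are illustrative:

```python
import torch

def lid_depth_bins(d_min: float, d_max: float, num_bins: int) -> torch.Tensor:
    # Linear-increasing discretization (LID, from CaDDN): bin widths grow
    # linearly with the index, so nearby depths get finer resolution.
    i = torch.arange(num_bins + 1, dtype=torch.float32)
    return d_min + (d_max - d_min) * i * (i + 1) / (num_bins * (num_bins + 1))

def minmax_normalize(pts: torch.Tensor, lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    # Min-max normalize ego-frame points into [0, 1] over the perception range.
    return (pts - lo) / (hi - lo)

edges = lid_depth_bins(1.0, 61.2, 64)  # 65 bin edges from 1.0 m to 61.2 m
```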
- 3D position encoder
  - Input: 2D features, 3D coordinate points
  - Output: 3D-position-aware features
  - Process:
    - Feed the 3D coordinate points through an MLP to get a 3D position embedding
    - Pass the 2D features through a 1x1 convolution and add the 3D position embedding, giving the 3D features
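A minimal sketch of the position encoder. The layer names, the hidden width, and the 64 depth bins are assumptions, not the repo's exact code:

```python
import torch
import torch.nn as nn

C, D = 256, 64                               # feature channels, depth bins

# MLP over the stacked (x, y, z) coordinates of all depth bins per pixel,
# implemented with 1x1 convs so it applies pointwise across the feature map
position_encoder = nn.Sequential(
    nn.Conv2d(3 * D, C * 4, kernel_size=1), nn.ReLU(),
    nn.Conv2d(C * 4, C, kernel_size=1))
input_proj = nn.Conv2d(C, C, kernel_size=1)  # the 1x1 conv on the 2D features

feat_2d = torch.randn(6, C, 16, 44)          # per-camera backbone features
coords3d = torch.rand(6, 3 * D, 16, 44)      # normalized 3D points per pixel

# 3D-position-aware features: projected 2D features + 3D position embedding
feat_3d = input_proj(feat_2d) + position_encoder(coords3d)
```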
- Feature extraction: ResNet, Swin Transformer, VoVNetV2
- Flatten the 3D features
  - Input: the per-camera feature maps, flattened across cameras and spatial dims into one token sequence
- Object queries
  - Input: reference points, nn.Embedding(self.num_query, 3) (PETRHead.reference_points)
  - Output: object queries Q0 (query_embeds)
  - Process: uniform initialization in [0, 1], sine/cosine position encoding (pos2posemb3d), then an MLP (PETRHead.query_embedding)
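The query-generation step can be sketched end to end. The pos2posemb3d body below is a DETR-style sine/cosine encoding reconstructed from the shapes in PETRHead (3 coords x 128 dims = embed_dims*3//2 = 384), so treat its details as an approximation:

```python
import math
import torch
import torch.nn as nn

def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
    # Sine/cosine encoding of (x, y, z) reference points, DETR-style,
    # producing 3 * num_pos_feats dims per point.
    pos = pos * 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)
    embs = []
    for i in range(3):                       # x, y, z in turn
        p = pos[..., i, None] / dim_t
        embs.append(torch.stack((p[..., 0::2].sin(), p[..., 1::2].cos()), dim=-1).flatten(-2))
    return torch.cat(embs, dim=-1)

num_query, embed_dims = 900, 256
reference_points = nn.Embedding(num_query, 3)
nn.init.uniform_(reference_points.weight.data, 0, 1)   # uniform init in [0, 1]
query_embedding = nn.Sequential(                       # 384 -> 256, as in PETRHead
    nn.Linear(embed_dims * 3 // 2, embed_dims), nn.ReLU(),
    nn.Linear(embed_dims, embed_dims))
query_embeds = query_embedding(pos2posemb3d(reference_points.weight))
```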
- Decoder
  - Multi-head attention and FFN; during training the object queries learn to extract high-dimensional object features
- Head and loss functions
  - Classification head: outputs the object class, trained with focal loss
  - Box head: outputs offsets relative to the reference points, trained with L1 loss
  - Hungarian matching against ground truth, same as DETR3D
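A minimal sketch of the set-to-set Hungarian matching, with an illustrative cost of classification probability plus pairwise L1 box distance (the actual cost weights and terms differ in the repo):

```python
import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_gt, num_classes = 6, 3, 10
pred_logits = torch.randn(num_queries, num_classes)
pred_boxes = torch.rand(num_queries, 7)        # e.g. (x, y, z, w, l, h, yaw)
gt_labels = torch.tensor([1, 4, 7])
gt_boxes = torch.rand(num_gt, 7)

cls_cost = -pred_logits.softmax(-1)[:, gt_labels]   # higher prob -> lower cost
reg_cost = torch.cdist(pred_boxes, gt_boxes, p=1)   # pairwise L1 distance
cost = cls_cost + reg_cost                          # (num_queries, num_gt)

row, col = linear_sum_assignment(cost.numpy())      # Hungarian assignment
# prediction row[i] is matched to ground-truth box col[i]; unmatched
# queries are supervised as "no object"
```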