Point Transformer V3(PTv3)【3:上采样unpooling】

发布于:2025-08-30 ⋅ 阅读:(18) ⋅ 点赞:(0)

PTV3专题目录

序列化编码

降采样SerializedPooling

上采样SerializedUnpooling

背景

点云分割的原始代码

class SerializedUnpooling(PointModule):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))

        if norm_layer is not None:
            self.proj.add(norm_layer(out_channels))
            self.proj_skip.add(norm_layer(out_channels))

        if act_layer is not None:
            self.proj.add(act_layer())
            self.proj_skip.add(act_layer())

        self.traceable = traceable

    def forward(self, point):
        assert "pooling_parent" in point.keys()
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pop("pooling_inverse")
        point = self.proj(point)
        parent = self.proj_skip(parent)
        parent.feat = parent.feat + point.feat[inverse]

        if self.traceable:
            parent["unpooling_parent"] = point
        return parent

基本功能

SerializedUnpooling 的原理可以概括为:

利用池化时记录的“父子关系”(pooling_parentpooling_inverse),将低分辨率点云的特征精准地“广播”回高分辨率的空间结构中,并与该层的跳跃连接特征相加,实现信息的上采样和融合。

这个过程完全依赖于 SerializedPooling 提供的溯源信息,两者构成了一个高效且完全可逆的编解码对。

具体思路

我们来详细解析 SerializedUnpooling 的工作原理。它是 SerializedPooling 的逆操作,在典型的 U-Net 架构中扮演着上采样(Upsampling)特征融合的关键角色。

1. 总体目标

SerializedUnpooling 的目标是将在网络深层、低分辨率的点云特征图(点数少,但语义信息丰富)恢复到其在网络浅层时的高分辨率(点数多,但空间细节丰富),并融合来自浅层的特征。

简单来说,它要回答这个问题:如何将一个点的特征“分配”回当初合并成它的那 N 个点?

2. 核心原理:利用池化时保存的信息

SerializedUnpooling 的“魔法”完全依赖于 SerializedPooling 在执行下采样时,有先见之明地保存了两个关键信息:

  1. pooling_parent: 这是一个指向池化前的、高分辨率的 Point 对象的引用。它包含了原始的点云坐标、特征(也就是跳跃连接 (Skip Connection) 的特征)以及所有序列化信息。
  2. pooling_inverse: 这是一个索引张量。如果池化前的点云有 N 个点,池化后有 M 个点,那么 pooling_inverse 就是一个长度为 N 的张量。它的第 i 个元素的值 j 表示:原始的第 i 个点在池化时被合并到了新的第 j 个点中。

3. __init__ (初始化)

在创建 SerializedUnpooling 实例时,会定义几个关键部分:

  • in_channels: 输入的低分辨率点云的特征维度。
  • skip_channels: 来自 pooling_parent (跳跃连接) 的高分辨率点云的特征维度。
  • out_channels: 最终输出的高分辨率点云的特征维度。
  • self.proj: 一个线性层,用于处理来自深层网络(低分辨率)的特征,将其维度从 in_channels 映射到 out_channels
  • self.proj_skip: 另一个线性层,用于处理来自跳跃连接(高分辨率)的特征,将其维度从 skip_channels 映射到 out_channels

4. forward (前向传播) - 详细步骤

我们用一个具体的例子来贯穿整个流程。假设 SerializedPooling 将一个 14 个点的点云(我们称之为 P_high)池化成了一个 2 个点的点云(我们称之为 P_low)。

现在,SerializedUnpooling 的输入 point 就是 P_low

  1. 断言和信息恢复:

    • assert "pooling_parent" in point.keys(): 检查 P_low 是否保存了指向 P_high 的引用。
    • parent = point.pop("pooling_parent"): 取出这个引用,现在 parent 就是 P_high(包含14个点及其原始特征)。
    • inverse = point.pop("pooling_inverse"): 取出那个长度为 14 的索引张量,其内容类似 [0,0,0,0,0,0,0,0, 1,1,1,1,1,1]
  2. 特征投影:

    • point = self.proj(point): 将 P_low 的特征(2个点)通过线性层进行变换。假设维度从 (2, in_channels) 变为 (2, out_channels)
    • parent = self.proj_skip(parent): 将 P_high 的跳跃连接特征(14个点)也通过线性层进行变换。维度从 (14, skip_channels) 变为 (14, out_channels)
    • 现在,深层特征和浅层特征都有了相同的维度 out_channels,为相加做好了准备。
  3. 核心操作:特征广播 (Broadcast) 与融合:

    • 这是最关键的一步:parent.feat = parent.feat + point.feat[inverse]
    • 我们来分解 point.feat[inverse]
      • point.feat 是一个 (2, out_channels) 的张量,包含两个低分辨率点的特征。
      • inverse 是一个 (14,) 的索引张量。
      • point.feat[inverse] 是一个高级索引操作。它会根据 inverse 中的值,为 parent 中的每一个点(共14个)从 point.feat 中“拾取”对应的特征。
      • 由于 inverse 的前8个值是0,所以新生成张量的前8行都会是 point.feat 的第0行。
      • 由于 inverse 的后6个值是1,所以新生成张量的后6行都会是 point.feat 的第1行。
      • 最终,point.feat[inverse] 的结果是一个 (14, out_channels) 的张量。这个操作高效地将低分辨率的特征广播回了原始高分辨率点云的每个点上。
    • parent.feat + ...: 将广播后的特征与 parent 自身经过投影的特征进行逐元素相加。这就完成了深层语义信息和浅层空间细节的融合。
  4. 返回结果:

    • 函数返回更新后的 parent 对象。它现在是一个高分辨率的点云(14个点),但其特征已经融合了来自网络深层的信息。

QA

一、point.feat[inverse]的实现,其实就是根据inverse把特征进行复制操作吧

是的,理解非常准确!

point.feat[inverse] 的实现,本质上就是一次高效的**复制和广播(Broadcast)**操作。它利用 inverse 张量作为索引,将低分辨率的特征“复制”并“扩展”到高分辨率的点云结构上。

用之前的例子来具体说明这个过程:

  1. 输入状态:

    • point.feat: 这是低分辨率点云的特征张量。假设池化后剩下 2 个点,特征维度为 64。那么 point.feat 的形状是 (2, 64)。我们可以把它看作 [特征_A, 特征_B]
    • inverse: 这是在池化时生成的索引张量,记录了每个原始点属于哪个新聚类。它的长度是高分辨率点云的点数,比如 14。它的内容是 [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
  2. point.feat[inverse] 操作:

    • 这是一个 PyTorch 的高级索引(Advanced Indexing)操作。
    • 它会创建一个新的形状为 (14, 64) 的张量。
    • 它遍历 inverse 中的每一个元素:
      • 对于 inverse 的前 8 个元素,值都是 0。所以,它会取出 point.feat 中索引为 0 的行(也就是 特征_A),并把这个特征复制 8 次,作为新张量的前 8 行。
      • 对于 inverse 的后 6 个元素,值都是 1。所以,它会取出 point.feat 中索引为 1 的行(也就是 特征_B),并把这个特征复制 6 次,作为新张量的后 6 行。
  3. 结果:

    • 最终 point.feat[inverse] 的结果是一个 (14, 64) 的张量,其内容看起来像:
      [
       特征_A,  // 对应原始点1
       特征_A,  // 对应原始点2
       ... (共8行)
       特征_A,  // 对应原始点8
       特征_B,  // 对应原始点9
       ... (共6行)
       特征_B   // 对应原始点14
      ]
      

总结来说,point.feat[inverse] 这行代码用一种极其高效和向量化的方式,完成了“将一个点的特征广播给所有属于它的子点”这一任务,是实现上采样的核心步骤。


网站公告

今日签到

点亮在社区的每一天
去签到