MeshGPT 笔记

发布于:2025-06-12 ⋅ 阅读:(20) ⋅ 点赞:(0)

[2311.15475] MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers

https://library.scholarcy.com/try

真正意义上的AI生成三维模型MESHGPT来袭!_哔哩哔哩_bilibili

GitHub - lucidrains/meshgpt-pytorch: Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch

 lucidrains/meshgpt-pytorch | DeepWikimmm


为什么使用 ResNet

使用 ResNet 块在 decoder 中有几个重要原因:

  1. 梯度流动:残差连接帮助解决深层网络中的梯度消失问题,使得训练更加稳定
  2. 特征保持:残差连接确保重要的低层特征能够直接传递到输出层
  3. 重建质量:对于 mesh 重建任务,保持细节特征非常重要,ResNet 的跳跃连接有助于保持这些细节
  4. 训练稳定性:在从量化的离散表示重建连续坐标的过程中,ResNet 提供了更稳定的训练动态

这种设计选择反映了 MeshGPT 需要高质量重建 3D mesh 几何信息的要求,ResNet 架构在这种重建任务中已被证明是有效的。


从量化码到3D坐标的完整路径

重建过程的完整路径可以在decode_from_codes_to_faces方法中看到: meshgpt_pytorch.py:898-947

这个过程展示了ResNet如何参与完整的重建流程:

  1. 量化码通过quantizer.get_output_from_indices转换为连续特征
  2. 这些特征经过ResNet块处理(在decode方法中)
  3. 最终通过to_coor_logits转换为坐标logits
  4. 通过undiscretize函数转换回连续的3D坐标

为什么需要undiscretize

  1. 离散化的逆过程:在编码阶段,连续的3D坐标被离散化为整数索引,解码时需要恢复为连续坐标
  2. 精度恢复:通过精确的数学逆变换,尽可能恢复原始坐标的精度
  3. 范围映射:确保输出坐标在正确的3D空间范围内(通常是-1到1)

图卷积在 MeshAutoencoder 中的使用

MeshAutoencoder 的编码器部分,使用了 SAGEConv(GraphSAGE 卷积)来处理 mesh 数据。具体实现可以看到: meshgpt_pytorch.py:52 meshgpt_pytorch.py:543 meshgpt_pytorch.py:552-560

在编码过程中的应用

encode 方法中,图卷积被用来处理面(face)之间的连接关系: meshgpt_pytorch.py:764-770

这里首先使用初始的 SAGE 卷积,然后通过多个编码器层进行特征提取。

为什么使用图卷积

使用图卷积在 mesh 处理中有几个关键原因:

  1. 自然的数据结构匹配:3D mesh 本质上就是一个图结构,其中面(faces)是节点,面之间的邻接关系构成边。图卷积能够自然地处理这种不规则的拓扑结构。

  2. 邻域信息聚合:SAGEConv 能够聚合每个面的邻居面的特征信息,这对于理解 mesh 的局部几何结构至关重要。

  3. 置换不变性:图卷积对节点的排列顺序不敏感,这对于处理不同拓扑结构的 mesh 很重要。

  4. 特征传播:通过多层图卷积,局部的几何特征可以传播到更大的邻域,帮助模型理解全局的 mesh 结构。

面边缘的自动推导

系统还能自动从面信息推导出面之间的连接关系: data.py:297-302

这确保了图卷积有正确的邻接信息来工作。

Notes

图卷积的使用体现了 MeshGPT 对 3D mesh 数据结构特性的深度理解。相比于传统的卷积神经网络,图卷积更适合处理 mesh 这种不规则的几何数据,能够有效捕获面与面之间的空间关系和几何特征。


你询问的是在 MeshGPT 中一个面(face)有几个邻面。让我基于代码来解释这个问题。

面邻接关系的确定

在 MeshGPT 中,面的邻接关系是通过 derive_face_edges_from_faces 函数来确定的: data.py:297-302

邻面数量的计算逻辑

邻面的数量取决于共享顶点的判断标准: data.py:322-333

关键的判断逻辑在这里:

  1. 默认情况:两个面如果共享至少 2 个顶点,则被认为是邻面(face_edges_vertices_threshold = 2
  2. 可选配置:如果设置 neighbor_if_share_one_vertex = True,则共享 1 个顶点就算邻面

实际邻面数量

对于三角形网格(triangular mesh),一个面的邻面数量通常是:

  • 最多 3 个邻面:因为三角形有 3 条边,每条边最多与一个相邻三角形共享
  • 实际数量可能更少:在网格边界处,面可能只有 1-2 个邻面

这个计算过程在函数中通过以下步骤实现: data.py:324-327

在图卷积中的使用

这些邻接关系被用于图卷积操作中: meshgpt_pytorch.py:764-770


输入网格输出网格? 

是的,MeshGPT 确实可以实现输入网格输出网格的功能,但这个过程分为两个阶段:

重建(Reconstruction)

MeshAutoencoder 可以接收网格数据并重建相同的网格。在训练过程中,它学习将输入网格编码为离散代码,然后解码回原始网格形状: meshgpt_pytorch.py:979-989

训练时输出重建损失,推理时可以返回重建的网格坐标: meshgpt_pytorch.py:1030-1044

生成(Generation)

MeshTransformer 可以生成全新的网格。它使用训练好的 MeshAutoencoder 作为标记器,学习网格的分布模式,然后生成新的网格: meshgpt_pytorch.py:1422-1424

生成过程返回面坐标和面掩码: README.md:87

工作流程

  1. 训练阶段:MeshAutoencoder 学习网格的压缩表示
  2. 生成阶段:MeshTransformer 使用学到的表示生成新网格
  3. 解码阶段:将生成的代码转换回3D网格坐标

整个系统还支持文本条件生成,可以根据文本描述(如"一把椅子")生成相应的3D网格。

Notes

虽然输入和输出都是网格,但 MeshTransformer 生成的是全新的网格,而不是简单的重建。这使得它能够创造出训练数据中没有的新颖3D形状。


import torch
from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)
autoencoder = MeshAutoencoder(
    num_discrete_coors=128
)

transformer = MeshTransformer(
    autoencoder,
    dim=512,
    max_seq_len=12
)

uncached_faces_coors, _ = transformer.generate(cache_kv=False, temperature=0)
cached_faces_coors, _ = transformer.generate(cache_kv=True, temperature=0)

# 打印生成的网格面坐标
print("Uncached Faces Coordinates:")
print(uncached_faces_coors)

print("\nCached Faces Coordinates:")
print(cached_faces_coors)

# 检查两次生成的面坐标是否在数值上接近
assert torch.allclose(uncached_faces_coors, cached_faces_coors)
  1. 结构和维度:生成的面坐标是一个四维张量,形状为 [batch_size, num_faces, num_vertices_per_face, 3]。在这个例子中,batch_size 是 1,num_faces 是 2,每个面有 3 个顶点,每个顶点有 3 个坐标值(x, y, z)。

我看看它画了啥

tensor([[[[-0.3828,  0.8984,  0.1484],
          [-0.2266,  0.0391, -0.3516],
          [ 0.9609, -0.1641, -0.6484]],

         [[-0.0547, -0.7891, -0.7109],
          [-0.5391, -0.5547, -0.2266],
          [ 0.9609,  0.1484, -0.6953]],

         [[ 0.9297, -0.1953,  0.5391],
          [-0.1016,  0.0547, -0.0703],
          [-0.1016, -0.1016, -0.2891]],

         [[-0.7109,  0.2578,  0.5391],
          [-0.1016, -0.5547,  0.7422],
          [-0.2891, -0.1641,  0.8984]],

         [[ 0.4453,  0.2578,  0.5391],
          [-0.1016, -0.5547,  0.0234],
          [-0.2891, -0.1641, -0.3047]],

         [[-0.0547, -0.4609,  0.1484],
          [ 0.0234,  0.0547, -0.3516],
          [-0.2891, -0.9609,  0.2109]],

         [[-0.0547, -0.5391,  0.5391],
          [ 0.3047, -0.7734, -0.9922],
          [-0.2891,  0.1484, -0.3047]],

         [[ 0.2734,  0.1797,  0.1484],
          [ 0.5078, -0.5547, -0.1172],
          [ 0.0859,  0.1484, -0.6953]],

         [[-0.5703,  0.5078,  0.5547],
          [-0.1016,  0.3516, -0.6641],
          [ 0.9609,  0.2266, -0.3047]],

         [[-0.2578,  0.9453,  0.5391],
          [-0.1016,  0.3516, -0.8359],
          [ 0.5859, -0.1641, -0.6953]],

         [[ 0.3516,  0.8047, -0.7109],
          [-0.1016,  0.3516,  0.0234],
          [ 0.9609, -0.1641,  0.4297]],

         [[ 0.4141,  0.2578, -0.5391],
          [-0.1016, -0.5547,  0.7422],
          [ 0.9609,  0.3047, -0.3047]],

         [[-0.0547, -0.3047, -0.7734],
          [-0.3203, -0.5547, -0.8359],
          [-0.5234,  0.3047, -0.3047]],

         [[-0.0547,  0.2578,  0.5391],
          [-0.1016,  0.3516, -0.7266],
          [-0.2891,  0.3047,  0.8984]],

         [[ 0.2734, -0.7578,  0.5391],
          [-0.1016, -0.5547,  0.0234],
          [-0.2891, -0.2891, -0.6953]],

         [[ 0.4141,  0.8047, -0.5391],
          [-0.1016, -0.7734, -0.6641],
          [ 0.9609,  0.4922, -0.6797]],

         [[ 0.2734,  0.2578, -0.5391],
          [-0.1016,  0.0703,  0.1953],
          [ 0.9609,  0.1484, -0.7578]],

         [[-0.0547,  0.2578,  0.1484],
          [-0.1016,  0.3516,  0.9922],
          [-0.2891, -0.3203,  0.2109]],

         [[-0.0547,  0.2578, -0.7734],
          [-0.2266,  0.0391, -0.3516],
          [ 0.5547, -0.6797,  0.8984]],

         [[ 0.4453, -0.3047, -0.1328],
          [-0.3203, -0.5547, -0.0703],
          [ 0.0859,  0.3672, -0.6953]]]])
from OCC.Core.gp import gp_Pnt
from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_MakePolygon, BRepBuilderAPI_MakeFace
from OCC.Display.SimpleGui import init_display

def create_face_from_points(points):
    # 创建顶点
    vertices = [gp_Pnt(*p) for p in points]
    # 创建多边形
    polygon = BRepBuilderAPI_MakePolygon()
    for vertex in vertices:
        polygon.Add(vertex)
    polygon.Close()
    # 创建面
    face = BRepBuilderAPI_MakeFace(polygon.Wire())
    return face.Face()

def draw_faces(points_list):
    display, start_display, add_menu, add_function_to_menu = init_display()
    for points in points_list:
        face = create_face_from_points(points)
        display.DisplayShape(face)
    display.FitAll()
    start_display()

# 你的面的点列表
points_list = [
    [[-0.3828,  0.8984,  0.1484], [-0.2266,  0.0391, -0.3516], [ 0.9609, -0.1641, -0.8484]],
    [[-0.0547, -0.7891, -0.7109], [-0.5391, -0.5547, -0.2266], [ 0.9609,  0.1484, -0.6953]],
    [[ 0.9297, -0.1953,  0.5391], [-0.1016,  0.0547, -0.0703], [-0.1016, -0.1016, -0.2891]],
    [[-0.7109,  0.2578,  0.5391], [-0.1016, -0.5547,  0.7422], [-0.2891, -0.1641,  0.8984]],
    [[ 0.4453,  0.2578,  0.5391], [-0.1016, -0.5547,  0.0234], [-0.2891, -0.1641, -0.3047]],
    [[-0.0547, -0.4609,  0.1484], [ 0.0234,  0.0547, -0.3516], [-0.2891, -0.9609,  0.2109]],
    [[-0.0547, -0.5391,  0.5391], [ 0.3047, -0.7734, -0.9922], [-0.2891,  0.1484, -0.3047]],
    [[ 0.2734,  0.1797,  0.1484], [ 0.4078, -0.5547, -0.1172], [ 0.0859,  0.1484, -0.6953]],
    [[-0.5703,  0.5078,  0.5547], [-0.1016,  0.3516, -0.6641], [ 0.9609,  0.2266, -0.3047]],
    [[-0.2578,  0.9453,  0.5391], [-0.1016,  0.3516, -0.8359], [ 0.5859, -0.1641, -0.6953]],
    [[ 0.3516,  0.8047, -0.7109], [-0.1016,  0.3516,  0.0234], [ 0.9609, -0.1641,  0.4297]],
    [[ 0.4141,  0.2578, -0.5391], [-0.1016, -0.5547,  0.7422], [ 0.9609,  0.3047, -0.3047]],
    [[-0.0547, -0.3047, -0.7734], [-0.3203, -0.5547, -0.8359], [-0.5234,  0.3047, -0.3047]],
    [[-0.0547,  0.2578,  0.5391], [-0.1016,  0.3516, -0.7266], [-0.2891,  0.3047,  0.8984]],
    [[ 0.2734, -0.7578,  0.5391], [-0.1016, -0.5547,  0.0234], [-0.2891, -0.2891, -0.6953]],
    [[ 0.4141,  0.8047, -0.5391], [-0.1016, -0.7734, -0.6641], [ 0.9609,  0.4922, -0.6797]],
    [[ 0.2734,  0.2578, -0.5391], [-0.1016,  0.0703,  0.1953], [ 0.9609,  0.1484, -0.7578]],
    [[ 0.0547,  0.2578,  0.1484], [-0.1016,  0.3516,  0.9922], [-0.2891, -0.3203,  0.2109]],
    [[-0.0547,  0.2578, -0.7734], [-0.2266,  0.0391, -0.3316], [ 0.5547, -0.6797,  0.8984]],
    [[ 0.4453, -0.3047, -0.1328], [-0.3203, -0.5547, -0.0703], [ 0.0859,  0.3672, -0.6953]],
]

draw_faces(points_list)