[2311.15475] MeshGPT: Generating Triangle Meshes with Decoder-Only Transformers
https://library.scholarcy.com/try
真正意义上的AI生成三维模型MESHGPT来袭!_哔哩哔哩_bilibili
lucidrains/meshgpt-pytorch | DeepWikimmm
为什么使用 ResNet
使用 ResNet 块在 decoder 中有几个重要原因:
- 梯度流动:残差连接帮助解决深层网络中的梯度消失问题,使得训练更加稳定
- 特征保持:残差连接确保重要的低层特征能够直接传递到输出层
- 重建质量:对于 mesh 重建任务,保持细节特征非常重要,ResNet 的跳跃连接有助于保持这些细节
- 训练稳定性:在从量化的离散表示重建连续坐标的过程中,ResNet 提供了更稳定的训练动态
这种设计选择反映了 MeshGPT 需要高质量重建 3D mesh 几何信息的要求,ResNet 架构在这种重建任务中已被证明是有效的。
从量化码到3D坐标的完整路径
重建过程的完整路径可以在decode_from_codes_to_faces
方法中看到: meshgpt_pytorch.py:898-947
这个过程展示了ResNet如何参与完整的重建流程:
- 量化码通过
quantizer.get_output_from_indices
转换为连续特征 - 这些特征经过ResNet块处理(在
decode
方法中) - 最终通过
to_coor_logits
转换为坐标logits - 通过
undiscretize
函数转换回连续的3D坐标
为什么需要undiscretize
- 离散化的逆过程:在编码阶段,连续的3D坐标被离散化为整数索引,解码时需要恢复为连续坐标
- 精度恢复:通过精确的数学逆变换,尽可能恢复原始坐标的精度
- 范围映射:确保输出坐标在正确的3D空间范围内(通常是-1到1)
图卷积在 MeshAutoencoder 中的使用
在 MeshAutoencoder
的编码器部分,使用了 SAGEConv(GraphSAGE 卷积)来处理 mesh 数据。具体实现可以看到: meshgpt_pytorch.py:52 meshgpt_pytorch.py:543 meshgpt_pytorch.py:552-560
在编码过程中的应用
在 encode
方法中,图卷积被用来处理面(face)之间的连接关系: meshgpt_pytorch.py:764-770
这里首先使用初始的 SAGE 卷积,然后通过多个编码器层进行特征提取。
为什么使用图卷积
使用图卷积在 mesh 处理中有几个关键原因:
自然的数据结构匹配:3D mesh 本质上就是一个图结构,其中面(faces)是节点,面之间的邻接关系构成边。图卷积能够自然地处理这种不规则的拓扑结构。
邻域信息聚合:SAGEConv 能够聚合每个面的邻居面的特征信息,这对于理解 mesh 的局部几何结构至关重要。
置换不变性:图卷积对节点的排列顺序不敏感,这对于处理不同拓扑结构的 mesh 很重要。
特征传播:通过多层图卷积,局部的几何特征可以传播到更大的邻域,帮助模型理解全局的 mesh 结构。
面边缘的自动推导
系统还能自动从面信息推导出面之间的连接关系: data.py:297-302
这确保了图卷积有正确的邻接信息来工作。
Notes
图卷积的使用体现了 MeshGPT 对 3D mesh 数据结构特性的深度理解。相比于传统的卷积神经网络,图卷积更适合处理 mesh 这种不规则的几何数据,能够有效捕获面与面之间的空间关系和几何特征。
你询问的是在 MeshGPT 中一个面(face)有几个邻面。让我基于代码来解释这个问题。
面邻接关系的确定
在 MeshGPT 中,面的邻接关系是通过 derive_face_edges_from_faces
函数来确定的: data.py:297-302
邻面数量的计算逻辑
邻面的数量取决于共享顶点的判断标准: data.py:322-333
关键的判断逻辑在这里:
- 默认情况:两个面如果共享至少 2 个顶点,则被认为是邻面(
face_edges_vertices_threshold = 2
) - 可选配置:如果设置
neighbor_if_share_one_vertex = True
,则共享 1 个顶点就算邻面
实际邻面数量
对于三角形网格(triangular mesh),一个面的邻面数量通常是:
- 最多 3 个邻面:因为三角形有 3 条边,每条边最多与一个相邻三角形共享
- 实际数量可能更少:在网格边界处,面可能只有 1-2 个邻面
这个计算过程在函数中通过以下步骤实现: data.py:324-327
在图卷积中的使用
这些邻接关系被用于图卷积操作中: meshgpt_pytorch.py:764-770
输入网格输出网格?
是的,MeshGPT 确实可以实现输入网格输出网格的功能,但这个过程分为两个阶段:
重建(Reconstruction)
MeshAutoencoder 可以接收网格数据并重建相同的网格。在训练过程中,它学习将输入网格编码为离散代码,然后解码回原始网格形状: meshgpt_pytorch.py:979-989
训练时输出重建损失,推理时可以返回重建的网格坐标: meshgpt_pytorch.py:1030-1044
生成(Generation)
MeshTransformer 可以生成全新的网格。它使用训练好的 MeshAutoencoder 作为标记器,学习网格的分布模式,然后生成新的网格: meshgpt_pytorch.py:1422-1424
生成过程返回面坐标和面掩码: README.md:87
工作流程
- 训练阶段:MeshAutoencoder 学习网格的压缩表示
- 生成阶段:MeshTransformer 使用学到的表示生成新网格
- 解码阶段:将生成的代码转换回3D网格坐标
整个系统还支持文本条件生成,可以根据文本描述(如"一把椅子")生成相应的3D网格。
Notes
虽然输入和输出都是网格,但 MeshTransformer 生成的是全新的网格,而不是简单的重建。这使得它能够创造出训练数据中没有的新颖3D形状。
import torch
from meshgpt_pytorch import (
MeshAutoencoder,
MeshTransformer
)
autoencoder = MeshAutoencoder(
num_discrete_coors=128
)
transformer = MeshTransformer(
autoencoder,
dim=512,
max_seq_len=12
)
uncached_faces_coors, _ = transformer.generate(cache_kv=False, temperature=0)
cached_faces_coors, _ = transformer.generate(cache_kv=True, temperature=0)
# 打印生成的网格面坐标
print("Uncached Faces Coordinates:")
print(uncached_faces_coors)
print("\nCached Faces Coordinates:")
print(cached_faces_coors)
# 检查两次生成的面坐标是否在数值上接近
assert torch.allclose(uncached_faces_coors, cached_faces_coors)
结构和维度:生成的面坐标是一个四维张量,形状为
[batch_size, num_faces, num_vertices_per_face, 3]
。在这个例子中,batch_size
是 1,num_faces
是 2,每个面有 3 个顶点,每个顶点有 3 个坐标值(x, y, z)。
我看看它画了啥
tensor([[[[-0.3828, 0.8984, 0.1484],
[-0.2266, 0.0391, -0.3516],
[ 0.9609, -0.1641, -0.6484]],
[[-0.0547, -0.7891, -0.7109],
[-0.5391, -0.5547, -0.2266],
[ 0.9609, 0.1484, -0.6953]],
[[ 0.9297, -0.1953, 0.5391],
[-0.1016, 0.0547, -0.0703],
[-0.1016, -0.1016, -0.2891]],
[[-0.7109, 0.2578, 0.5391],
[-0.1016, -0.5547, 0.7422],
[-0.2891, -0.1641, 0.8984]],
[[ 0.4453, 0.2578, 0.5391],
[-0.1016, -0.5547, 0.0234],
[-0.2891, -0.1641, -0.3047]],
[[-0.0547, -0.4609, 0.1484],
[ 0.0234, 0.0547, -0.3516],
[-0.2891, -0.9609, 0.2109]],
[[-0.0547, -0.5391, 0.5391],
[ 0.3047, -0.7734, -0.9922],
[-0.2891, 0.1484, -0.3047]],
[[ 0.2734, 0.1797, 0.1484],
[ 0.5078, -0.5547, -0.1172],
[ 0.0859, 0.1484, -0.6953]],
[[-0.5703, 0.5078, 0.5547],
[-0.1016, 0.3516, -0.6641],
[ 0.9609, 0.2266, -0.3047]],
[[-0.2578, 0.9453, 0.5391],
[-0.1016, 0.3516, -0.8359],
[ 0.5859, -0.1641, -0.6953]],
[[ 0.3516, 0.8047, -0.7109],
[-0.1016, 0.3516, 0.0234],
[ 0.9609, -0.1641, 0.4297]],
[[ 0.4141, 0.2578, -0.5391],
[-0.1016, -0.5547, 0.7422],
[ 0.9609, 0.3047, -0.3047]],
[[-0.0547, -0.3047, -0.7734],
[-0.3203, -0.5547, -0.8359],
[-0.5234, 0.3047, -0.3047]],
[[-0.0547, 0.2578, 0.5391],
[-0.1016, 0.3516, -0.7266],
[-0.2891, 0.3047, 0.8984]],
[[ 0.2734, -0.7578, 0.5391],
[-0.1016, -0.5547, 0.0234],
[-0.2891, -0.2891, -0.6953]],
[[ 0.4141, 0.8047, -0.5391],
[-0.1016, -0.7734, -0.6641],
[ 0.9609, 0.4922, -0.6797]],
[[ 0.2734, 0.2578, -0.5391],
[-0.1016, 0.0703, 0.1953],
[ 0.9609, 0.1484, -0.7578]],
[[-0.0547, 0.2578, 0.1484],
[-0.1016, 0.3516, 0.9922],
[-0.2891, -0.3203, 0.2109]],
[[-0.0547, 0.2578, -0.7734],
[-0.2266, 0.0391, -0.3516],
[ 0.5547, -0.6797, 0.8984]],
[[ 0.4453, -0.3047, -0.1328],
[-0.3203, -0.5547, -0.0703],
[ 0.0859, 0.3672, -0.6953]]]])
from OCC.Core.gp import gp_Pnt
from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_MakePolygon, BRepBuilderAPI_MakeFace
from OCC.Display.SimpleGui import init_display
def create_face_from_points(points):
# 创建顶点
vertices = [gp_Pnt(*p) for p in points]
# 创建多边形
polygon = BRepBuilderAPI_MakePolygon()
for vertex in vertices:
polygon.Add(vertex)
polygon.Close()
# 创建面
face = BRepBuilderAPI_MakeFace(polygon.Wire())
return face.Face()
def draw_faces(points_list):
display, start_display, add_menu, add_function_to_menu = init_display()
for points in points_list:
face = create_face_from_points(points)
display.DisplayShape(face)
display.FitAll()
start_display()
# 你的面的点列表
points_list = [
[[-0.3828, 0.8984, 0.1484], [-0.2266, 0.0391, -0.3516], [ 0.9609, -0.1641, -0.8484]],
[[-0.0547, -0.7891, -0.7109], [-0.5391, -0.5547, -0.2266], [ 0.9609, 0.1484, -0.6953]],
[[ 0.9297, -0.1953, 0.5391], [-0.1016, 0.0547, -0.0703], [-0.1016, -0.1016, -0.2891]],
[[-0.7109, 0.2578, 0.5391], [-0.1016, -0.5547, 0.7422], [-0.2891, -0.1641, 0.8984]],
[[ 0.4453, 0.2578, 0.5391], [-0.1016, -0.5547, 0.0234], [-0.2891, -0.1641, -0.3047]],
[[-0.0547, -0.4609, 0.1484], [ 0.0234, 0.0547, -0.3516], [-0.2891, -0.9609, 0.2109]],
[[-0.0547, -0.5391, 0.5391], [ 0.3047, -0.7734, -0.9922], [-0.2891, 0.1484, -0.3047]],
[[ 0.2734, 0.1797, 0.1484], [ 0.4078, -0.5547, -0.1172], [ 0.0859, 0.1484, -0.6953]],
[[-0.5703, 0.5078, 0.5547], [-0.1016, 0.3516, -0.6641], [ 0.9609, 0.2266, -0.3047]],
[[-0.2578, 0.9453, 0.5391], [-0.1016, 0.3516, -0.8359], [ 0.5859, -0.1641, -0.6953]],
[[ 0.3516, 0.8047, -0.7109], [-0.1016, 0.3516, 0.0234], [ 0.9609, -0.1641, 0.4297]],
[[ 0.4141, 0.2578, -0.5391], [-0.1016, -0.5547, 0.7422], [ 0.9609, 0.3047, -0.3047]],
[[-0.0547, -0.3047, -0.7734], [-0.3203, -0.5547, -0.8359], [-0.5234, 0.3047, -0.3047]],
[[-0.0547, 0.2578, 0.5391], [-0.1016, 0.3516, -0.7266], [-0.2891, 0.3047, 0.8984]],
[[ 0.2734, -0.7578, 0.5391], [-0.1016, -0.5547, 0.0234], [-0.2891, -0.2891, -0.6953]],
[[ 0.4141, 0.8047, -0.5391], [-0.1016, -0.7734, -0.6641], [ 0.9609, 0.4922, -0.6797]],
[[ 0.2734, 0.2578, -0.5391], [-0.1016, 0.0703, 0.1953], [ 0.9609, 0.1484, -0.7578]],
[[ 0.0547, 0.2578, 0.1484], [-0.1016, 0.3516, 0.9922], [-0.2891, -0.3203, 0.2109]],
[[-0.0547, 0.2578, -0.7734], [-0.2266, 0.0391, -0.3316], [ 0.5547, -0.6797, 0.8984]],
[[ 0.4453, -0.3047, -0.1328], [-0.3203, -0.5547, -0.0703], [ 0.0859, 0.3672, -0.6953]],
]
draw_faces(points_list)