transformer 输入三视图线段输出长宽高 笔记

发布于:2025-06-01 ⋅ 阅读:(24) ⋅ 点赞:(0)
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    The table is computed once for ``max_len`` positions and registered as a
    buffer of shape ``[max_len, 1, d_model]`` (not a trainable parameter), so it
    follows the module across devices and is stored in ``state_dict``.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Odd values
            are supported: the last channel then holds a sine with no paired
            cosine.
        max_len: Longest sequence ``forward`` can encode.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Frequencies 10000^(-2i / d_model), one per sin/cos channel pair.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are floor(d_model / 2) odd channels but ceil(d_model / 2)
        # frequencies; trim so an odd d_model does not raise a shape error.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # [max_len, d_model] -> [max_len, 1, d_model]: the size-1 axis
        # broadcasts over the batch axis of a [seq_len, batch, d_model] input.
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Args:
            x: Tensor whose dim 0 is the sequence axis and whose last dim is
                ``d_model``, e.g. ``[seq_len, batch, d_model]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is longer than ``max_len`` along dim 0.
        """
        seq_len = x.size(0)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            # Without this check the addition fails with an opaque size error,
            # or silently broadcasts one encoding over all steps if max_len == 1.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        return x + self.pe[:seq_len]

class TransformerModel(nn.Module):
    """Regress three values (length, width, height) from a set of input tokens.

    Each sample is flattened into a sequence of ``input_dim``-sized tokens.
    The tokens are linearly embedded, given sinusoidal positional encodings and
    run through a Transformer encoder. The result is mean-pooled over tokens
    and projected to 3 outputs.

    Args:
        input_dim: Number of values per token (e.g. 4 for one 2-D segment
            stored as 2 points x 2 coordinates).
        d_model: Embedding / encoder width (PyTorch requires it to be
            divisible by ``nhead``).
        nhead: Number of attention heads.
        nlayers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each encoder layer's feed-forward block.
        dropout: Dropout probability inside the encoder layers.
    """

    def __init__(self, input_dim, d_model, nhead, nlayers, dim_feedforward, dropout=0.5):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.nhead = nhead
        self.nlayers = nlayers
        self.dim_feedforward = dim_feedforward

        # Embedding: maps each input_dim-sized token to a d_model vector.
        self.embedding = nn.Linear(input_dim, d_model)

        # The encoder keeps PyTorch's default sequence-first layout
        # (batch_first=False); forward() transposes its input to match.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=nlayers)
        self.pos_encoder = PositionalEncoding(d_model)
        self.output_linear = nn.Linear(d_model, 3)  # length, width, height

    def forward(self, src):
        """Predict three values per sample.

        Args:
            src: Tensor of shape ``[batch, ...]`` whose per-sample element
                count is a multiple of ``input_dim``. An example is the demo's
                ``[batch, 3, 4, 2, 2]``: views, segments per view, and
                presumably 2 points x 2 coordinates.

        Returns:
            Tensor of shape ``[batch, 3]``.
        """
        batch_size = src.size(0)
        # [batch, ...] -> [batch, seq_len, input_dim]. reshape (unlike view)
        # also accepts non-contiguous inputs.
        tokens = src.reshape(batch_size, -1, self.input_dim)

        x = self.embedding(tokens)  # [batch, seq_len, d_model]

        # PositionalEncoding slices its table along dim 0, and the encoder is
        # sequence-first, so move the token axis to dim 0. Without this,
        # attention would mix different samples of the batch, and every token
        # of a sample would receive the same positional encoding.
        x = x.transpose(0, 1)  # [seq_len, batch, d_model]
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)

        # Mean-pool over the token (sequence) axis.
        pooled = x.mean(dim=0)  # [batch, d_model]

        return self.output_linear(pooled)  # [batch, 3]

# Model hyper-parameters.
# One token per line segment. The data below has shape [batch, 3, 4, 2, 2],
# i.e. 2 x 2 = 4 coordinate values per segment and 3 x 4 = 12 tokens per
# sample. (A value of 8 silently packed two segments into every token.)
input_dim = 4
d_model = 128  # Transformer model width
nhead = 8  # number of attention heads
nlayers = 6  # number of encoder layers
dim_feedforward = 256  # hidden size of the feed-forward sublayer

# Build the model.
model = TransformerModel(input_dim, d_model, nhead, nlayers, dim_feedforward)

# Loss function and optimizer.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Example data: one random sample and an assumed target (length, width, height).
input_data = torch.rand(1, 3, 4, 2, 2)
target_data = torch.tensor([[1.0, 2.0, 3.0]])

# Train for 100 epochs on the single sample (demonstrates the loop only).
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    output = model(input_data)
    loss = criterion(output, target_data)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

# Inference on a fresh random sample (eval mode disables dropout).
model.eval()
with torch.no_grad():
    test_input = torch.rand(1, 3, 4, 2, 2)
    predicted_dimensions = model(test_input)
    print(f"Predicted dimensions: {predicted_dimensions}")