import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    The table is precomputed once for ``max_len`` positions and stored as the
    (non-trainable) buffer ``pe`` with shape ``(max_len, 1, d_model)``. Even
    feature columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Feature size of the inputs. Odd values are supported.
        max_len: Maximum sequence length that ``forward`` can handle.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # One frequency per even column: ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns; for odd d_model that is one
        # fewer than len(div_term), so trim it to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over the
        # batch axis of a (seq, batch, d_model) input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        Args:
            x: Tensor laid out sequence-first, ``(seq, batch, d_model)``; the
                sequence length (dim 0) must not exceed ``max_len``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return x + self.pe[: x.size(0), :]
class TransformerModel(nn.Module):
    """Transformer encoder that regresses 3 values (length, width, height).

    Each input sample is regrouped into a sequence of tokens of ``input_dim``
    numbers. Each token is linearly embedded to ``d_model`` and given
    sinusoidal positional encodings. The tokens then go through ``nlayers``
    encoder layers, are mean-pooled over the sequence, and are projected to
    3 outputs.

    Args:
        input_dim: Numbers per token fed to the embedding layer.
        d_model: Hidden width of the encoder.
        nhead: Attention heads per encoder layer.
        nlayers: Number of stacked encoder layers.
        dim_feedforward: Inner width of each layer's feed-forward network.
        dropout: Dropout probability inside the encoder layers.
    """

    def __init__(self, input_dim, d_model, nhead, nlayers, dim_feedforward, dropout=0.5):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.nhead = nhead
        self.nlayers = nlayers
        self.dim_feedforward = dim_feedforward
        # Embedding layer: maps each token of input_dim coordinates to a d_model vector.
        self.embedding = nn.Linear(input_dim, d_model)
        # NOTE(review): nn.TransformerEncoder builds its layers as copies of this
        # template, so the template's own parameters appear to be unused in
        # forward (confirm against the installed PyTorch version). It is kept as
        # an attribute so existing state_dict keys remain loadable.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=nlayers)
        self.pos_encoder = PositionalEncoding(d_model)
        # Output head: length, width, height.
        self.output_linear = nn.Linear(d_model, 3)

    def forward(self, src):
        """Predict (length, width, height) for each sample in the batch.

        Args:
            src: Tensor whose first dimension is the batch. All remaining
                elements of a sample are regrouped into tokens of
                ``input_dim`` values, so their count must be divisible by
                ``input_dim``. The example script in this file feeds shape
                ``(batch, 3, 4, 2, 2)``, which yields 48 / 8 = 6 tokens per sample.

        Returns:
            Tensor of shape ``(batch, 3)``.
        """
        batch_size = src.size(0)
        # (batch, ...) -> (batch, seq, input_dim). reshape (unlike view) also
        # accepts non-contiguous input.
        tokens = src.reshape(batch_size, -1, self.input_dim)
        tokens = self.embedding(tokens)  # (batch, seq, d_model)
        # PositionalEncoding indexes positions by dim 0, and the encoder layer
        # is built without batch_first, so both expect (seq, batch, d_model).
        # Feeding batch-first data here would read positions off the batch axis
        # and make attention mix different samples instead of the tokens within
        # one sample.
        tokens = tokens.transpose(0, 1)  # (seq, batch, d_model)
        tokens = self.pos_encoder(tokens)
        encoded = self.transformer_encoder(tokens)  # (seq, batch, d_model)
        # Average over the sequence axis -> (batch, d_model).
        pooled = encoded.mean(dim=0)
        # Linear head -> (batch, 3): length, width, height.
        return self.output_linear(pooled)
# ---------------------------------------------------------------------------
# Model hyper-parameters
# ---------------------------------------------------------------------------
# Numbers per token handed to the embedding layer.
# NOTE(review): the original comment described this as 8 values per line
# segment (4 points x 2 coordinates), but the example tensors below end in
# dims (2, 2), i.e. 4 values per innermost group, so each token seems to pack
# two such groups — confirm the intended grouping.
input_dim = 8
d_model = 128  # hidden width of the Transformer
nhead = 8  # attention heads per encoder layer
nlayers = 6  # number of stacked encoder layers
dim_feedforward = 256  # inner width of each layer's feed-forward network

# Build the model; dropout keeps the constructor default.
model = TransformerModel(
    input_dim=input_dim,
    d_model=d_model,
    nhead=nhead,
    nlayers=nlayers,
    dim_feedforward=dim_feedforward,
)

# Objective and optimizer: mean-squared error on the 3 outputs, Adam at lr=1e-3.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Toy data: one random sample and a made-up (length, width, height) target.
input_data = torch.rand(1, 3, 4, 2, 2)
target_data = torch.tensor([[1.0, 2.0, 3.0]])

# Training: 100 full-batch steps on the single sample, logging every 10th step.
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    output = model(input_data)
    loss = criterion(output, target_data)
    loss.backward()
    optimizer.step()
    if not epoch % 10:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

# Inference on a fresh random sample, in eval mode and without building a graph.
model.eval()
with torch.no_grad():
    test_input = torch.rand(1, 3, 4, 2, 2)
    predicted_dimensions = model(test_input)
print(f"Predicted dimensions: {predicted_dimensions}")