TensorBoard

发布于:2025-04-06 ⋅ 阅读:(14) ⋅ 点赞:(0)

以下是 TensorBoard 在 PyTorch 中的使用指南,涵盖安装、基础操作和高级功能,帮助你高效监控和可视化模型训练过程。


1. 安装与验证

安装
pip install tensorboard
验证安装
tensorboard --version  # 输出版本号,例如:2.12.0

2. 基础使用流程

步骤 1:导入 SummaryWriter
from torch.utils.tensorboard import SummaryWriter

# 创建 writer 对象,指定日志保存目录(默认:runs/当前时间)
writer = SummaryWriter("logs")  # 日志会保存在 ./logs 文件夹中
步骤 2:记录数据

记录标量(损失、准确率等)

for epoch in range(100):
    loss = 0.1 * (100 - epoch)  # 模拟损失值
    accuracy = 0.01 * epoch     # 模拟准确率
    
    # 记录单指标
    writer.add_scalar("Loss/train", loss, epoch)
    # 记录多指标(同一图表)
    writer.add_scalars("Metrics", {"train_loss": loss, "train_acc": accuracy}, epoch)

记录模型结构

model = ...  # 你的PyTorch模型
dummy_input = torch.randn(1, 3, 224, 224)  # 输入样例(batch_size=1, 3通道, 224x224图像)
writer.add_graph(model, dummy_input)  # 生成计算图

记录图像/特征图

images = torch.randn(8, 3, 224, 224)  # 模拟一批图像
writer.add_images("Training Samples", images, epoch)  # 记录图像批次

# 记录卷积层特征图(假设features是中间层的输出)
features = model.conv_layers(dummy_input)
writer.add_image("Feature Maps", features[0], epoch, dataformats="HW")  # 单通道特征图

记录直方图(权重分布)

for name, param in model.named_parameters():
    writer.add_histogram(f"Parameters/{name}", param, epoch)

3. 启动 TensorBoard 服务

在终端中运行以下命令(注意路径匹配):

tensorboard --logdir=logs --port=6006
  • --logdir:指定日志目录(与 SummaryWriter 的路径一致)。
  • --port:指定端口(默认6006,若冲突可改为其他端口如6007)。

访问浏览器:http://localhost:6006(或远程服务器IP:端口)。


4. 核心功能详解

Scalars(标量)
  • 监控训练/验证损失、准确率、学习率等指标。
  • 技巧:使用 / 命名层级(如 Loss/trainLoss/val),TensorBoard会自动分组。
Graphs(模型结构)
  • 可视化模型计算图,检查数据流和层连接。
  • 注意:确保 add_graph 的输入张量形状与实际数据一致。
Images(图像)
  • 查看输入数据、数据增强效果或中间特征图。
  • 支持格式:单张图像(add_image)或批次图像(add_images)。
Histograms(直方图)
  • 分析权重/偏置的分布变化,检测梯度消失或爆炸。
PR Curves & ROC(分类任务)
  • 记录精确率-召回率曲线或ROC曲线:
from torchmetrics import PrecisionRecallCurve
pr_curve = PrecisionRecallCurve(task="binary")
precision, recall, _ = pr_curve(predictions, labels)
writer.add_pr_curve("PR Curve", labels, predictions, epoch)

5. 高级功能

Embedding Projector(降维可视化)
# 记录嵌入向量(如特征提取后的高维数据)
embeddings = model.get_embeddings(data)  # 假设输出形状 [N, 512]
writer.add_embedding(embeddings, metadata=labels, label_img=images, global_step=epoch)
  • 在 TensorBoard 的 Projector 标签页中查看PCA/t-SNE降维结果。
Hyperparameter Tuning(超参数对比)
# 记录超参数和对应结果
writer.add_hparams(
    {"lr": 0.01, "batch_size": 32},
    {"hparam/accuracy": 0.95, "hparam/loss": 0.1},
)

6. 常见问题

Q1:TensorBoard 页面无数据?
  • 检查 --logdir 路径是否与 SummaryWriter 的路径一致。
  • 确保数据已写入日志(调用 writer.flush() 或关闭 writer)。
Q2:如何远程访问 TensorBoard?

在服务器运行:

tensorboard --logdir=logs --port=6006 --bind_all

本地通过 ssh 转发端口:

ssh -L 6006:localhost:6006 user@server_ip
Q3:日志文件过大?
  • 定期清理旧日志或按实验分目录保存(如 logs/exp1, logs/exp2)。
  • 使用 writer.close() 确保资源释放。

7. 完整代码示例

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet18

# 初始化
writer = SummaryWriter("logs")
model = resnet18(pretrained=False)
dummy_input = torch.randn(1, 3, 224, 224)

# 记录模型结构
writer.add_graph(model, dummy_input)

# 模拟训练循环
for epoch in range(100):
    loss = 0.1 * (100 - epoch)
    accuracy = 0.01 * epoch
    
    # 记录标量
    writer.add_scalar("Loss/train", loss, epoch)
    writer.add_scalars("Metrics", {"train_acc": accuracy}, epoch)
    
    # 记录直方图
    for name, param in model.named_parameters():
        writer.add_histogram(f"Params/{name}", param, epoch)

writer.close()

总结

  • 核心步骤:安装 → 创建 SummaryWriter → 记录数据 → 启动服务。
  • 常用方法add_scalar(标量)、add_graph(模型结构)、add_image(图像)、add_histogram(权重分布)。
  • 进阶功能:嵌入投影、超参数对比、PR曲线等。

掌握这些操作,你可以轻松实现训练过程的可视化与深度分析!