tensorboard in pytorch

发布于:2025-05-11 ⋅ 阅读:(24) ⋅ 点赞:(0)
# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms

# Image display
import matplotlib.pyplot as plt
import numpy as np

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

# Gather datasets and prepare them for consumption
transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))])

# Store separate training and validations splits in ./data
training_set = torchvision.datasets.FashionMNIST('./data',
    download=True,
    train=True,
    transform=transform)
validation_set = torchvision.datasets.FashionMNIST('./data',
    download=True,
    train=False,
    transform=transform)

training_loader = torch.utils.data.DataLoader(training_set,
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=2)


validation_loader = torch.utils.data.DataLoader(validation_set,
                                                batch_size=4,
                                                shuffle=False,
                                                num_workers=2)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

这段代码的作用是加载并预处理 FashionMNIST 数据集,FashionMNIST 是一个包含 28x28 像素灰度图像的服装数据集,用于图像分类任务。我们来逐步解释代码:

1. 数据预处理(Transformations)

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])
  • transforms.Compose:是一个工具,可以将多个操作按顺序组合在一起,对数据进行一系列处理。

    • transforms.ToTensor():这一步将图像转换成 PyTorch 中可以使用的 张量(Tensor) 格式。并且把像素值从 [0, 255] 范围变成 [0, 1] 之间。

    • transforms.Normalize((0.5,), (0.5,)):这一步对图像做归一化处理,把像素值的均值调整为 0.5,标准差调整为 0.5。这样可以帮助神经网络更快地收敛,提高训练效果。

2. 加载数据集(Dataset)

training_set = torchvision.datasets.FashionMNIST('./data',
    download=True,
    train=True,
    transform=transform)
validation_set = torchvision.datasets.FashionMNIST('./data',
    download=True,
    train=False,
    transform=transform)
  • 这里使用了 torchvision.datasets.FashionMNIST 来加载 FashionMNIST 数据集。

    • ./data:指定了数据集保存的目录。如果没有下载,程序会自动下载到这个目录。

    • download=True:表示如果本地没有数据集,就自动从网上下载。

    • train=True:表示加载的是 训练集

    • train=False:表示加载的是 测试集(验证集)

    • transform=transform:表示在加载数据时,应用之前定义的预处理操作(转为张量并归一化)。

3. 数据加载器(DataLoader)

training_loader = torch.utils.data.DataLoader(training_set,
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=2)
validation_loader = torch.utils.data.DataLoader(validation_set,
                                                batch_size=4,
                                                shuffle=False,
                                                num_workers=2)
  • DataLoader 是一个用来批量加载数据的工具,这样可以提高训练时的效率。

    • batch_size=4:每次加载 4 张图片,这些图片会组成一个批次(batch)。

    • shuffle=True:表示在训练时,数据会被随机打乱,这样有助于防止模型记住数据的顺序,提高训练效果。

    • shuffle=False:在验证时,不需要打乱数据,保持数据顺序。

    • num_workers=2:表示有 2 个子进程 用来并行加载数据,这样可以加快数据读取的速度。

4. 类标签(Classes)

classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')
  • 这是一个包含 10 个类别 的列表,每个类别对应着一种服装类型。每张图片的标签就表示它是哪一类服装:

    • 0 代表 T恤/上衣

    • 1 代表 裤子

    • 2 代表 毛衣

    • 3 代表 裙子

    • 4 代表 外套

    • 5 代表 凉鞋

    • 6 代表 衬衫

    • 7 代表 运动鞋

    • 8 代表

    • 9 代表 高筒靴

总结

这段代码的主要目的是:

  1. 加载 FashionMNIST 数据集,并将其分为训练集和验证集。

  2. 对数据进行 预处理,将图像转换为 PyTorch 张量,并进行标准化。

  3. 使用 DataLoader 批量加载数据,方便训练时一次性输入多个图像。

  4. 定义了图像的 类别标签,这些标签用于后续训练和评估模型。

简而言之,这段代码准备了数据,使得我们可以直接拿来训练一个图像分类模型,自动识别不同类型的服装。

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
# Extract a batch of 4 images
dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)

# Default log_dir argument is "runs" - but it's good to be specific
# torch.utils.tensorboard.SummaryWriter is imported above
writer = SummaryWriter('runs/fashion_mnist_experiment_1')

# Write image data to TensorBoard log dir
writer.add_image('Four Fashion-MNIST Images', img_grid)
writer.flush()

# To view, start TensorBoard on the command line with:
#   tensorboard --logdir=runs
# ...and open a browser tab to http://localhost:6006/

这段代码是用来从训练集中取出一批图像,并将它们可视化显示出来的,适用于调试或初步查看训练数据。下面我们用通俗的中文逐行解释:


✅ 函数:显示图片的辅助函数

def matplotlib_imshow(img, one_channel=False):

这个函数叫做 matplotlib_imshow,是一个用于显示图像的辅助工具。你传入一张图,它就会用 matplotlib 来画出来。


    if one_channel:
        img = img.mean(dim=0)

如果你传入的 one_channel=True,说明你希望把图像变成灰度图(单通道),所以这一步就会把图像的颜色通道取平均,变成灰度图。


    img = img / 2 + 0.5     # unnormalize

这一步是“反归一化”,因为之前图像被归一化成了 [-1, 1] 之间(即 Normalize((0.5,), (0.5,))),现在要显示出来,就把它还原成 [0, 1] 范围的像素值:

原始像素 = 归一化像素 × 0.5 + 0.5


    npimg = img.numpy()

将图像从 PyTorch 的 Tensor 转换成 NumPy 数组,因为 matplotlib 是用 NumPy 数据绘图的。


    if one_channel:
        plt.imshow(npimg, cmap="Greys")

如果是灰度图,就用灰色的色图(colormap)来显示图像。


    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

如果是彩色图,需要把通道维度(channel)从 [C, H, W] 转换成 [H, W, C],因为 matplotlib 要求的是这种格式(高度、宽度、颜色通道)。


✅ 提取一批图像(batch)

dataiter = iter(training_loader)
images, labels = dataiter.next()

这两行代码的意思是:

  • 从训练数据加载器 training_loader 中取出一个批次的数据(batch size = 4),并分为图像和标签两部分。

  • images 就是形状为 [4, 1, 28, 28] 的张量,表示 4 张 28×28 的灰度图像。


✅ 创建图像网格并显示

img_grid = torchvision.utils.make_grid(images)

将这 4 张图像拼成一个图像网格(横着排列),返回一个可以显示的图像张量。


matplotlib_imshow(img_grid, one_channel=True)

调用刚才定义的函数 matplotlib_imshow(),将这个图像网格以灰度图的形式显示出来。


📌 总结一句话:

这段代码从训练集中随机取出 4 张图像,拼成一张图像网格,并使用 matplotlib 显示出来,方便我们快速了解训练数据的样子。

# Select a random subset of data and corresponding labels
def select_n_random(data, labels, n=100):
    assert len(data) == len(labels)

    perm = torch.randperm(len(data))
    return data[perm][:n], labels[perm][:n]

# Extract a random subset of data
images, labels = select_n_random(training_set.data, training_set.targets)

# get the class labels for each image
class_labels = [classes[label] for label in labels]

# log embeddings
features = images.view(-1, 28 * 28)
writer.add_embedding(features,
                    metadata=class_labels,
                    label_img=images.unsqueeze(1))
writer.flush()
writer.close()

好的,我来用通俗易懂的中文一步步详细解释这段代码的意思:


1. 选择一小部分数据

def select_n_random(data, labels, n=100):
    assert len(data) == len(labels)

    perm = torch.randperm(len(data))
    return data[perm][:n], labels[perm][:n]

这段代码定义了一个函数 select_n_random,它的作用是:

  • 随机从一大堆数据(data)和对应的标签(labels)里面,挑选出 n 个样本(默认是 100 个)。

  • torch.randperm(len(data)) 是生成一个随机排列,比如数据有 10000张图片,它就打乱这10000个顺序,然后从中取前 n 个。

  • 注意:assert len(data) == len(labels) 是确保图片和标签数量一一对应,不然就出错。

简单理解: 👉 就像从一堆扑克牌中随机抽出 100 张牌一样。


2. 提取随机子集

images, labels = select_n_random(training_set.data, training_set.targets)

这一行用上面定义的函数,从你的训练集 (training_set) 中随机挑了 100 张图片及对应的标签。

  • images 保存了随机抽到的图片

  • labels 保存了这些图片对应的正确类别


3. 拿到每张图片的文字标签

class_labels = [classes[label] for label in labels]
  • 这里把数字标签(比如0、1、2)转换成了文字标签(比如 "T-shirt", "Trouser")。

  • classes 是一个列表,比如:

    classes = ['T-shirt', 'Trouser', 'Pullover', ...]
    

简单理解: 👉 就是把数字变成更好懂的中文/英文类别名。


4. 准备好数据做可视化(特征降维)

features = images.view(-1, 28 * 28)
  • 把每张图片(原来是 28×28 的二维小图片)拉平成一行 784 个数字,因为后面要把这些数字送给 TensorBoard 画图。

简单理解: 👉 把小方块图片拉成长长的一条数据线。


5. 写入到 TensorBoard 的 Embedding

writer.add_embedding(features,
                    metadata=class_labels,
                    label_img=images.unsqueeze(1))
  • 把这些图片的特征对应的文字标签、以及原始图片,全部写入到 TensorBoard。

  • metadata=class_labels:就是告诉 TensorBoard,这个点对应的是什么类别。

  • label_img=images.unsqueeze(1):把图片加一个通道数(变成 1 通道),符合 TensorBoard 要求的格式。

简单理解: 👉 把这些图片的数据、名字和图片本身,全部打包进 TensorBoard,方便后面可视化查看。


6. 刷新并关闭文件

writer.flush()
writer.close()
  • flush() 是确保所有数据被保存到磁盘,不然可能还有东西留在内存里没写完。

  • close() 是关闭日志文件,结束写入。


总结一下通俗版流程:

  1. 随机选 100张训练图片。

  2. 拿到图片的文字类别

  3. 拉平成一行数据

  4. 写到 TensorBoard,方便后面用图形界面直观地看。

最终你可以在 TensorBoard 上看到:

  • 每一张图片在空间中的分布(像小点点一样)

  • 点点上可以标注类别名字,甚至直接看到小图片!


网站公告

今日签到

点亮在社区的每一天
去签到