常用组件详解(二):torchsummary

发布于:2024-07-03 ⋅ 阅读:(16) ⋅ 点赞:(0)


一、基本使用

  torchsummary库是一个好用的模型可视化工具,用于帮助开发者把握每个网络层级的细节,包括其中的连接和维度。使用方法:

from torchsummary import summary

库中仅有一个函数:

summary(model, input_size, batch_size=-1, device="cuda"):
  • model:模型对象。
  • input_size:输入数据的格式,使用(C,H,W)格式。
  • batch_size:批数据的数量。
  • device:使用的设备。

  以自定义的LeNet网络模型为例:

import torch
from torch import nn
from torchsummary import summary

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 手写数字图片大小为32*32,故需填充2个像素
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.Linear(in_features=120, out_features=84),
            nn.Linear(in_features=84, out_features=10),
        )

    def forward(self, x):
        return self.model(x)

myLeNet = LeNet().to(device)
print(summary(myLeNet, input_size=(1, 28, 28), batch_size=64, device='cuda'))

在这里插入图片描述

二、常见指标

2.1Input size

  Input size表示输入数据的大小。在上述例子中,batch_size=64,每张图片大小为(1,28,28),而Pytorch默认使用float32(双精度浮点数)占4字节,则每个batch所用内存大小为:
64 x 1 x 28 x 28 x 4 = 200 , 704 ( B y t e s ) 64x1x28x28x4=200,704(Bytes) 64x1x28x28x4=200,704Bytes
转化为以MB为单位:
200 , 704 / 102 4 2 ( B y t e s ) = 0.19140625 ( B y t e s ) 200,704/1024^2(Bytes)=0.19140625(Bytes) 200,704/10242Bytes=0.19140625Bytes
约等于0.19MB。

2.2Forward/backward pass size

https://blog.csdn.net/weixin_43589323/article/details/137105988?ops_request_misc=&request_id=&biz_id=102&utm_term=torchsummary&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduweb~default-5-137105988.142v100pc_search_result_base2&spm=1018.2226.3001.4187


网站公告

今日签到

点亮在社区的每一天
去签到