PyTorch深度学习总结
第六章 PyTorch中张量(Tensor)微分操作
文章目录
前言
上文介绍了PyTorch中张量(Tensor)
的计算
操作,本文将介绍张量
的微分
(torch.autograd
)操作。
一、torch.autograd模块
torch.autograd
是 PyTorch 中用于自动求导
的核心工具包,它提供了自动计算张量梯度的功能。训练模型通常需要计算损失函数关于模型参数的梯度,以便使用优化算法更新参数。
基本原理
torch.autograd
通过构建计算图(computational graph)
来跟踪张量上的所有操作。
计算图
是一个有向无环图(DAG)
,其中节点表示张量,边表示操作。当你对一个张量进行操作时,torch.autograd 会记录这些操作,并构建相应的计算图。
在需要计算梯度时,torch.autograd
会使用反向传播算法(backpropagation)
沿着计算图反向传播,从最终的输出张量开始,逐步计算每个操作的梯度,并累积到需要求梯度的张量上。
二、主要功能和使用方法
1. 张量的 requires_grad 属性
requires_grad
是张量的一个布尔属性,用于指定是否需要对该张量计算梯度。
如果将一个张量的requires_grad
设置为True
,torch.autograd
会跟踪该张量上的所有操作,并在需要时计算梯度。import torch # 创建一个需要计算梯度的张量 x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) # 进行一些操作 y = x * 2 # 此时 y 也会自动设置 requires_grad 为 True print(y.requires_grad) # 输出: True
2. backward() 方法
backward()
方法用于计算梯度,只能处理标量输出。
当调用backward()
时,torch.autograd
会从调用该方法的张量开始,沿着计算图反向传播
,计算所有requires_grad
为True
的张量的梯度。import torch #生成可以计算梯度的张量 x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) #对张量进行操作 y = x * 2 # 对 y 求和得到标量 y_sum = y.sum() # 计算梯度 y_sum.backward() # 查看 x 的梯度 print(x.grad) # 输出: tensor([2., 2., 2.])
注意:
在 PyTorch 中backward()
方法默认只能处理标量输出。这是因为梯度本质上是损失函数关于模型参数的导数,而导数是一个标量函数相对于另一个标量的变化率。
下方是一段错误代码:
import torch #生成可以计算梯度的张量 x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) #对张量进行操作 y = x * 2 # 计算梯度 y.backward() # 查看 x 的梯度 print(x.grad)
输出结果为:
RuntimeError: grad can be implicitly created only for scalar outputs
解决方法:
对非标量输出进行聚合操作(如求和、求均值等),将其转换为标量,再调用backward()
方法。示例:
import torch #生成可以计算梯度的张量 x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) #对张量进行操作 y = torch.sum(x * 2) # 计算梯度 y.backward() # 查看 x 的梯度 print(x.grad) # 输出: tensor([2., 2., 2.])
在调用
backward()
之前,使用sum()
或mean()
函数对输出进行聚合。
3. torch.no_grad() 上下文管理器
在某些情况下,你可能不需要计算梯度,例如在模型推理阶段。
torch.no_grad()
上下文管理器可以临时禁用梯度计算,从而节省内存和计算资源。import torch x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) with torch.no_grad(): y = x * 2 # 在 no_grad 上下文管理器中,y 的 requires_grad 为 False print(y.requires_grad) # 输出: False
三、函数总结
函数 | 描述 |
---|---|
torch.tensor([],requires_grad=True) |
允许对该张量计算梯度 |
backward() |
计算所有允许计算梯度张量的梯度 |
torch.no_grad() |
上下文管理器可以临时禁用梯度计算 |