循环神经网络(四):GRU
GRU 原理解析
什么是 GRU?
GRU(Gated Recurrent Unit,门控循环单元)是循环神经网络(RNN)的一种高效变体。与传统RNN相比,GRU通过引入更新门(update gate)和重置门(reset gate) 两种门控机制,显著提升了模型处理序列数据的能力。
核心机制与优势
GRU通过双门控系统实现对信息流的精准控制:
- 选择性记忆:智能决定哪些信息需要保留或丢弃
- 梯度管理:有效缓解了传统RNN训练中的梯度消失或爆炸问题
- 关键特征提取:能够自动学习并保留序列中最具代表性的特征数据
这种设计使GRU在保持结构简洁性的同时,能够捕获序列中的长距离依赖关系,在多种序列建模任务中表现出色。
根据这幅图来讲:
1、更新门
紫色+蓝色流程:zt = sigmoid(Wz xt + Uz ht−1)
2、重置门
紫色+蓝色流程:rt = sigmoid(Wr xt+Ur ht−1)
注意: 虽然算法和更新门相同,但是两者使用的权重是不同的
3、计算候选隐藏状态
紫色+蓝色流程:ht` = tanh(U( rt ⊙ ht−1 ) + W xt)
4、更新隐藏状态
红色+深红色流程:ht = zt ⊙ ht−1 + ( 1 − zt ) ⊙ ht`
到这儿其实和卷积中的残差链接有些类似,也是连接一个原始输入来加深记忆。
两个门是一个 0~1 的系数值,用加权的方式来更新隐藏状态,那么隐藏状态的数据大小不会成指数膨胀
zt权重值,可以理解为上一层隐藏层状态对于这一层的重要程度,这样就解决了梯度消失以及梯度爆炸的问题。
代码实现:
from torch import nn
class GRUCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
# 更新门参数
self.Wz = nn.Parameter(torch.randn(input_size, hidden_size), requires_grad=True)
self.Uz = nn.Parameter(torch.randn(hidden_size, hidden_size), requires_grad=True)
# 重置门参数
self.Wr = nn.Parameter(torch.randn(input_size, hidden_size), requires_grad=True)
self.Ur = nn.Parameter(torch.randn(hidden_size, hidden_size), requires_grad=True)
# 计算候选隐藏状态的参数
self.U = nn.Parameter(torch.randn(hidden_size, hidden_size), requires_grad=True)
self.W = nn.Parameter(torch.randn(input_size, hidden_size), requires_grad=True)
def forward(self, x, h):
# 1. 更新门
zt = self.sigmoid(x @ self.Wz + h @ self.Uz)
# 2. 重置门
rt = self.sigmoid(x @ self.Wr + h @ self.Ur)
# 3. 候选隐藏状态
_h = self.tanh((h * rt) @ self.U + x @ self.W)
# 4. 更新隐藏状态
h = h * zt + _h * (1 - zt)
return h
class GRU(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
self.cell = GRUCell(input_size, hidden_size)
self.fc_out = nn.Linear(hidden_size, hidden_size)
# x (N, L, input_size)
def forward(self, x, h=None):
N, L, input_size = x.shape
# 初始化 h
if h is None:
h = torch.zeros(N, self.hidden_size)
outputs = []
for i in range(L):
# 更新隐藏状态
h = self.cell(x[:, i], h)
# 输出
out = self.fc_out(h)
outputs.append(out)
# 堆叠
# (N, L, hidden_size)
outputs = torch.stack(outputs, dim=1)
return outputs, h
if __name__ == '__main__':
import torch
x = torch.rand(6, 4, 10)
model = GRU(10, 20)
y, h = model(x)
print(y.shape)
print(h.shape)