循环神经网络(四):GRU

发布于:2025-09-11 ⋅ 阅读:(25) ⋅ 点赞:(0)

循环神经网络(四):GRU

GRU 原理解析

什么是 GRU?

GRU(Gated Recurrent Unit,门控循环单元)是循环神经网络(RNN)的一种高效变体。与传统RNN相比,GRU通过引入更新门(update gate)重置门(reset gate) 两种门控机制,显著提升了模型处理序列数据的能力。

核心机制与优势

GRU通过双门控系统实现对信息流的精准控制:

  • 选择性记忆:智能决定哪些信息需要保留或丢弃
  • 梯度管理:有效缓解了传统RNN训练中的梯度消失或爆炸问题
  • 关键特征提取:能够自动学习并保留序列中最具代表性的特征数据

这种设计使GRU在保持结构简洁性的同时,能够捕获序列中的长距离依赖关系,在多种序列建模任务中表现出色。

根据这幅图来讲:
在这里插入图片描述

1、更新门

紫色+蓝色流程:zt = sigmoid(Wz xt + Uz ht−1)

2、重置门

紫色+蓝色流程:rt = sigmoid(Wr xt+Ur ht−1)

注意: 虽然算法和更新门相同,但是两者使用的权重是不同的

3、计算候选隐藏状态

紫色+蓝色流程:ht` = tanh(U( rt ⊙ ht−1 ) + W xt)

4、更新隐藏状态

红色+深红色流程:ht = zt ⊙ ht−1 + ( 1 − zt ) ⊙ ht`

到这儿其实和卷积中的残差链接有些类似,也是连接一个原始输入来加深记忆。

两个门是一个 0~1 的系数值,用加权的方式来更新隐藏状态,那么隐藏状态的数据大小不会成指数膨胀

zt权重值,可以理解为上一层隐藏层状态对于这一层的重要程度,这样就解决了梯度消失以及梯度爆炸的问题。

代码实现:

from torch import nn


class GRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        # 更新门参数
        self.Wz = nn.Parameter(torch.randn(input_size, hidden_size), requires_grad=True)
        self.Uz = nn.Parameter(torch.randn(hidden_size, hidden_size), requires_grad=True)
        # 重置门参数
        self.Wr = nn.Parameter(torch.randn(input_size, hidden_size), requires_grad=True)
        self.Ur = nn.Parameter(torch.randn(hidden_size, hidden_size), requires_grad=True)
        # 计算候选隐藏状态的参数
        self.U = nn.Parameter(torch.randn(hidden_size, hidden_size), requires_grad=True)
        self.W = nn.Parameter(torch.randn(input_size, hidden_size), requires_grad=True)

    def forward(self, x, h):
        # 1. 更新门
        zt = self.sigmoid(x @ self.Wz + h @ self.Uz)
        # 2. 重置门
        rt = self.sigmoid(x @ self.Wr + h @ self.Ur)
        # 3. 候选隐藏状态
        _h = self.tanh((h * rt) @ self.U + x @ self.W)
        # 4. 更新隐藏状态
        h = h * zt + _h * (1 - zt)
        return h


class GRU(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = GRUCell(input_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, hidden_size)

    # x (N, L, input_size)
    def forward(self, x, h=None):
        N, L, input_size = x.shape
        # 初始化 h
        if h is None:
            h = torch.zeros(N, self.hidden_size)

        outputs = []

        for i in range(L):
            # 更新隐藏状态
            h = self.cell(x[:, i], h)
            # 输出
            out = self.fc_out(h)
            outputs.append(out)

        # 堆叠
        # (N, L, hidden_size)
        outputs = torch.stack(outputs, dim=1)

        return outputs, h


if __name__ == '__main__':
    import torch

    x = torch.rand(6, 4, 10)
    model = GRU(10, 20)

    y, h = model(x)
    print(y.shape)
    print(h.shape)

网站公告

今日签到

点亮在社区的每一天
去签到