从代码学习深度学习 - GRU PyTorch版

发布于:2025-04-05 ⋅ 阅读:(9) ⋅ 点赞:(0)


前言

在深度学习领域,循环神经网络(RNN)及其变种如GRU(Gated Recurrent Unit,门控循环单元)在处理序列数据时表现出色。相比传统RNN,GRU通过更新门(Update Gate)和重置门(Reset Gate)简化了结构,同时保持了对长期依赖关系的建模能力。本篇博客将通过PyTorch实现一个基于GRU的文本生成模型,结合《The Time Machine》数据集,逐步解析代码实现的全过程。从数据预处理到模型训练,再到结果可视化,我们将深入探讨每个模块的功能,并展示完整的代码实现。


一、GRU模型介绍

GRU(Gated Recurrent Unit,门控循环单元)是循环神经网络(RNN)的一种改进变种,由Kyunghyun Cho等人在2014年提出。它旨在解决传统RNN在处理长序列时面临的梯度消失问题,同时通过更简洁的结构提升计算效率。相比LSTM(长短期记忆网络),GRU减少了一个门控单元,使用更新门(Update Gate)和重置门(Reset Gate)来控制信息的流动,从而在保持性能的同时降低参数量。

1.1 GRU的核心机制

在这里插入图片描述

GRU的工作原理基于两个关键的门控单元:

  1. 更新门(Update Gate, z t z_t zt
    更新门决定当前时间步的隐藏状态在多大程度上保留上一时间步的隐藏状态,以及接受多少新输入的信息。其计算公式为:
    z t = σ ( W z ⋅ [ h t − 1 , x t ] + b z ) z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) zt=σ(Wz[ht1,xt]+bz)
    其中, σ \sigma σ是sigmoid激活函数, h t − 1 h_{t-1} ht1 是上一时间步的隐藏状态, x t x_t xt 是当前输入, W z W_z Wz b z b_z bz 是可训练的参数。

  2. 重置门(Reset Gate, r t r_t rt
    重置门控制前一时间步的隐藏状态在多大程度上影响当前候选隐藏状态的计算。其计算公式为:
    r t = σ ( W r ⋅ [ h t − 1 , x t ] + b r ) r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) rt=σ(Wr[ht1,xt]+br)

基于这两个门,GRU计算候选隐藏状态和新隐藏状态:

  • 候选隐藏状态( h ~ t \tilde{h}_t h~t
    h ~ t = tanh ⁡ ( W h ⋅ [ r t ⊙ h t − 1 , x t ] + b h ) \tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) h~t=tanh(Wh[rt