10--门控循环神经网络&长短期记忆网络LSTM

发布于:2022-12-20 ⋅ 阅读:(828) ⋅ 点赞:(0)

门控循环神经网络

        在循环神经网络中,矩阵连乘导致的梯度消失问题和梯度爆炸。虽然存在梯度裁剪机制来缓解梯度爆炸,但无法解决梯度消失的问题。这些梯度异常在实践中表现为:

1,早期观测值对预测所有未来观测值具有非常重要的意义

2,一些词元没有相关的观测值

3,序列的各个部分之间存在逻辑中断

        为了解决这些问题,提出了“长短期记忆”,而门控循环神经网络就就是基于该方法提出的网络,为了解决长期记忆和反向传播中的梯度等问题。可以更好的捕捉时许数据中间隔较大的依赖关系。

        门控循环神经单元(GRU)中的隐状态结构如下所示:

         其中重置门允许我们控制“可能还想记住”的过去状态的数量,有助于捕获序列中的短期依赖关系。更新门将允许我们控制新状态中有多少个是旧状态的副本,有助于长期依赖关系。

        重置门\mathbf{R}_t和更新门\mathbf{Z}_t的更新公式如下,这里使用sigmoid函数来控制最后的值为(0,1)的向量:

\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r)

\mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),

        候选隐状态\tilde{\mathbf{H}}_t在t时刻的值为:\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h), 这里使用tanh非线性激活函数来确保候选隐状态中的值保持在区间(−1,1)中。

         如果不看\mathbf{R}_t则跟普通RNN中隐藏状态的更新是一致的。由于\mathbf{R}_t是(0,1)之间的向量,当为0时(极端状态),当前候选隐状态就只与当前输入有关,即“忘记”之前所有的记忆。使用\mathbf{R}_t来控制以往状态对当前的影响。

        当前输出隐状态\mathbf{H}_t由更新门\mathbf{Z}_t和候选隐状态\tilde{\mathbf{H}}_t共同决定,更新公式如下: 

\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.

        当 \mathbf{Z}_t接近1时,模型就倾向只保留旧状态,来自\mathbf{X}_t的当前输入信息就基本被忽略,从而有效地跳过了依赖链条中的时间步t

 代码实现:

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

LSTM 

        GRU是LSTM的简洁版本,LSTM中的隐藏状态不止包括隐状态Ht还包括记忆元Ct,具体架构如下所示:

        各变量的更新公式为:

         代码实现:

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)