从代码学习深度学习 - 深度循环神经网络 PyTorch 版

发布于:2025-04-06 ⋅ 阅读:(20) ⋅ 点赞:(0)


前言

深度学习中的循环神经网络(RNN)因其在处理序列数据(如文本、时间序列等)方面的强大能力而备受关注。随着模型复杂度的增加,单层 RNN 的局限性逐渐显现,深度循环神经网络(Deep RNN)应运而生。深度 RNN 通过堆叠多层 RNN 单元,能够捕捉更复杂的序列模式,在自然语言处理、语音识别等领域展现出卓越的性能。

本文将基于 PyTorch 实现一个深度循环神经网络,并以《时间机器》数据集为例,展示如何从数据加载到模型训练的全过程。我们将深入剖析代码,结合理论知识,帮助读者从实践中理解深度 RNN 的工作原理。


一、深度循环神经网络介绍

在这里插入图片描述

循环神经网络(RNN)是一种专门处理序列数据的神经网络,其核心在于通过隐藏状态在时间步之间传递信息。然而,单层 RNN 的表达能力有限,尤其是在面对长序列或复杂依赖关系时,容易出现梯度消失或爆炸的问题。深度循环神经网络通过堆叠多层 RNN 单元,增强了模型的层次结构,使其能够学习更深层次的特征和依赖关系。

在深度 RNN 中,每一层的输出会作为下一层的输入,隐藏状态在时间维度上逐层传递。常见的深度 RNN 变体包括堆叠的 LSTM(长短期记忆网络)或 GRU(门控循环单元),这些改进型单元通过门控机制有效缓解传统 RNN 的梯度问题。

二、数据准备

我们以《时间机器》文本数据集为例,展示如何加载和预处理数据。以下是相关代码:

import random
import re
import torch
from collections import Counter

def read_time_machine():
    """将时间机器数据集加载到文本行的列表中"""
    with open('timemachine.txt', 'r') as f:
        lines = f.readlines()
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]

def tokenize(lines, token='word'):
    """将文本行拆分为单词或字符词元"""
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        print(f'错误:未知词元类型:{
     token}')

def count_corpus(tokens):
    """统计词元的频率"""
    if not tokens:
        return Counter()
    if isinstance(tokens[0], list):
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        flattened_tokens = tokens
    return Counter(flattened_tokens)

class Vocab:
    """文本词表类,用于管理词元及其索引的映射关系"""
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
        counter = self._count_corpus(self.tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {
   token: idx for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    @staticmethod
    def _count_corpus(tokens):
        if not tokens:
            return Counter()
        if isinstance(tokens[0], list):
            tokens = [token for sublist in tokens for token in sublist]
        return Counter(tokens)

    def __len__(self):
        return len(self.idx_to_token