[Language Model Training] A Bidirectional LSTM Text Classifier in PyTorch: A Hotel-Review Classification Model


Text classification is a fundamental task in natural language processing, widely used in sentiment analysis, spam detection, news categorization, and more. This post walks through building a bidirectional LSTM text classifier with PyTorch, step by step: data preprocessing, model training, evaluation, and prediction.

[The model in this post performs binary classification of hotel reviews into positive and negative.]

[Demo (screenshots omitted)]

[The full source code is attached at the end; copy, paste, and run. It creates the required directories automatically and tells you where to put the dataset, which makes it very convenient!]

[I regularly share my learning notes for free; feel free to follow along and learn together!]

1. Environment Setup

1.1 Installing the Required Packages

First, make sure the following Python libraries are installed:

pip install torch numpy pandas jieba scikit-learn
  • torch: the PyTorch deep learning framework

  • numpy: numerical computing library

  • jieba: Chinese word segmentation tool

  • scikit-learn: provides train/test splitting and other utilities

(pandas appears in the install command for convenience, but the script below reads the CSV with Python's built-in csv module.)
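Before moving on, it helps to confirm that PyTorch can see your GPU (if you have one). The short check below is only a sketch and uses nothing beyond standard PyTorch calls:

import torch

# Print the installed PyTorch version and whether a CUDA device is visible.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # The training code later selects cuda/cpu with exactly this check.
    print(f"GPU: {torch.cuda.get_device_name(0)}")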

1.2 Project Layout

Before writing any code, let's plan the project's directory layout. The tree below is inferred from the file paths defined in the next section; everything except data.csv is created automatically on the first run:

project/
├── data/
│   ├── data.csv          # labeled reviews (you provide this)
│   ├── stopwords.txt     # stopword list (auto-created if missing)
│   ├── word_index.json   # word-to-index map (generated)
│   └── index.jsonl       # index sequences with labels (generated)
├── weights/
│   └── model.pth         # best model weights (generated)
└── main.py               # the training script from this post (any name works)
2. Paths and Basic Configuration

Next, we define the data file paths, the model save path, and a few basic configuration parameters:

# Data file paths and basic configuration
csv_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './data', 'data.csv'))
word_index_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './data', 'word_index.json'))
index_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './data', 'index.jsonl'))
model_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './weights', 'model.pth'))
stopwords_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './data', 'stopwords.txt'))  # stopword file path

# Text processing configuration
punk = ',。!?,.:;!?()[]{}"\' '  # punctuation characters to filter out
line_length_threshold = 100  # maximum length of a text sequence

# Device configuration: use the GPU if available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 

3. Directory and File Checks

To make sure the program can run, we create the necessary directories and check that the data file exists:

# Check for and create necessary directories
def create_necessary_directories():
    """Create every directory the program needs."""
    directories = [
        os.path.dirname(csv_path),
        os.path.dirname(model_path),
        os.path.dirname(stopwords_path)
    ]
    
    for dir_path in directories:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path, exist_ok=True)
            print(f"Created directory: {dir_path}")


# Check whether the CSV file exists
def check_csv_file():
    """Check whether the CSV file exists; if not, ask the user to add it."""
    if not os.path.exists(csv_path):
        print(f"\nError: no data file found at {csv_path}")
        print(f"Please place data.csv in the {os.path.dirname(csv_path)} directory")
        print("Expected CSV format: column 1 is the label (0 or 1), column 2 is the review text")
        return False
    return True

4. Stopword Handling

Stopword filtering is an important step in Chinese text processing: it removes common words that contribute little to classification:

# Load the stopword list
def load_stopwords():
    """Load the stopword list, creating a default one if the file does not exist."""
    if not os.path.exists(stopwords_path):
        # Default stopwords (extend as needed)
        default_stopwords = ["的", "了", "在", "是", "我", "有", "和", "就", "人", "都", "一", "一个", "上", "也", "很", "到", "说", "要", "去", "你", "会", "着", "没有", "看", "好", "自己", "这"]
        with open(stopwords_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(default_stopwords))
        print(f"Created default stopword list: {stopwords_path}")
    with open(stopwords_path, 'r', encoding='utf-8') as f:
        return set(f.read().splitlines())

5. Data Preprocessing

Text has to be converted into a numeric form the model can handle. This step builds the word index and turns each text into a sequence of indices:

# Build the word-to-index map (with stopword filtering)
def build_word_index():
    index = 1
    word_index = {}
    stopwords = load_stopwords()

    with open(csv_path, 'r', encoding='utf-8-sig') as f, \
            open(word_index_path, 'w', encoding='utf-8-sig') as f1:

        reader = csv.reader(f)
        for row in reader:
            if len(row) < 2:
                continue
            text = row[1]
            words = jieba.cut(text)  # Chinese word segmentation with jieba

            for word in words:
                # Skip stopwords, punctuation, empty tokens, and words already indexed
                if word.strip() and word not in punk and word not in stopwords and word not in word_index:
                    word_index[word] = index
                    index += 1

        json.dump(word_index, f1, ensure_ascii=False)
    print(f"Built word index: {word_index_path}")
    return word_index


# Load the word index
def load_word_index():
    if not os.path.exists(word_index_path):
        return build_word_index()
    with open(word_index_path, 'r', encoding='utf-8-sig') as f:
        return json.load(f)


# Convert texts to index sequences (with stopword filtering)
def convert_text_to_index():
    word_index = load_word_index()
    stopwords = load_stopwords()

    with open(csv_path, 'r', encoding='utf-8-sig') as f, \
            open(index_path, 'w', encoding='utf-8-sig') as f1:

        reader = csv.reader(f)
        for row in reader:
            if len(row) < 2:
                continue
            label, text = row[0], row[1]
            line_indexes = []
            words = jieba.cut(text)

            for word in words:
                # Skip stopwords, punctuation, and empty tokens; keep only words in the vocabulary
                if word.strip() and word not in punk and word not in stopwords and word in word_index:
                    line_indexes.append(word_index[word])

            # Write the index sequence and label to file
            f1.write(json.dumps({"indexes": line_indexes, "label": label}, ensure_ascii=False) + '\n')
    print(f"Converted texts to index sequences: {index_path}")

6. Loading the Preprocessed Data

Load the index sequences into memory and pad or truncate them to a uniform length:

def load_processed_data():
    if not os.path.exists(index_path):
        convert_text_to_index()
    labels = []
    inputs = []

    with open(index_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            data = json.loads(line)
            indexes = data["indexes"]
            label = int(data["label"])

            # Unify sequence length: truncate above the threshold, pad with 0 below it
            if len(indexes) > line_length_threshold:
                processed = indexes[:line_length_threshold]
            else:
                processed = indexes + [0] * (line_length_threshold - len(indexes))

            inputs.append(processed)
            labels.append(label)

    print(f"已加载预处理数据,共 {len(labels)} 条记录")
    return inputs, labels

7. Building the Data Loaders

We use PyTorch's DataLoader to handle batching and shuffling:

def get_data_loaders(test_size=0.2, random_state=42, batch_size=32):
    inputs, labels = load_processed_data()
    # Split into training and test sets; stratified sampling preserves the class ratio
    train_data, test_data, train_labels, test_labels = train_test_split(
        inputs, labels, test_size=test_size, random_state=random_state, stratify=labels
    )

    # Convert to TensorDataset
    train_dataset = TensorDataset(
        torch.tensor(train_data, dtype=torch.long),
        torch.tensor(train_labels, dtype=torch.long)
    )
    test_dataset = TensorDataset(
        torch.tensor(test_data, dtype=torch.long),
        torch.tensor(test_labels, dtype=torch.long)
    )

    # Create the data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print(f"数据加载完成 - 训练集: {len(train_dataset)} 条, 测试集: {len(test_dataset)} 条")
    return train_loader, test_loader

8. Defining the Bidirectional LSTM Model

We use a bidirectional LSTM combined with pooling to capture contextual information from the text:

class LSTMClassifier(nn.Module):
    """Text classification model based on a bidirectional LSTM with pooling"""

    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=256,
                 output_dim=2, num_layers=2, dropout=0.5):
        super(LSTMClassifier, self).__init__()
        # Embedding layer: maps word indices to dense vectors
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        
        # Bidirectional LSTM layer
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,  # first dimension is batch_size
            bidirectional=True,  # bidirectional LSTM
            dropout=dropout if num_layers > 1 else 0  # dropout only applies with multiple layers
        )
        
        # Fully connected layer: the bidirectional LSTM outputs hidden_dim*2, and average and max pooling are concatenated
        self.fc = nn.Linear(hidden_dim * 2 * 2, output_dim)  # 2 (bidirectional) * 2 (two pooling results)
        self.dropout = nn.Dropout(dropout)  # dropout layer to reduce overfitting

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        embedded = self.dropout(self.embedding(x))  # (batch_size, seq_len, embedding_dim)
        lstm_out, _ = self.lstm(embedded)  # (batch_size, seq_len, hidden_dim*2)
        
        # Pooling (instead of using only the last time step)
        avg_pool = torch.mean(lstm_out, dim=1)  # (batch_size, hidden_dim*2)
        max_pool, _ = torch.max(lstm_out, dim=1)  # (batch_size, hidden_dim*2)
        
        # Concatenate the two pooling results
        combined = torch.cat([avg_pool, max_pool], dim=1)  # (batch_size, hidden_dim*2*2)
        
        return self.fc(self.dropout(combined))
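Before training, a quick shape check can catch wiring mistakes: push a batch of random indices through an untrained model and look at the output dimensions. The vocabulary size of 1000 here is just a placeholder:

# Sanity check: a batch of 4 "reviews", each padded/truncated to length 100
model = LSTMClassifier(vocab_size=1000)
dummy = torch.randint(0, 1000, (4, 100), dtype=torch.long)
logits = model(dummy)
print(logits.shape)  # expected: torch.Size([4, 2]) - one score per class for each review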

9. Model Training and Evaluation

The training loop below adds a learning rate scheduler and early stopping to guard against overfitting:

def train_model(epochs=50, lr=0.001, batch_size=32):
    train_loader, test_loader = get_data_loaders(batch_size=batch_size)
    vocab_size = get_vocab_size() + 1  # +1 because indices start at 1 (0 is padding)

    # Initialize the model, optimizer, and loss function
    model = LSTMClassifier(vocab_size).to(device)
    optimizer = opt.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    # Learning rate scheduler: halve the LR if the test loss does not drop for 3 epochs
    scheduler = opt.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.5, verbose=True
    )

    best_test_acc = 0.0
    no_improve_epochs = 0  # counter of consecutive epochs without improvement
    early_stop_patience = 10  # early stopping threshold

    print(f"Starting training on device: {device}")
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, train_correct = 0.0, 0

        for texts, labels in train_loader:
            texts, labels = texts.to(device), labels.to(device)
            optimizer.zero_grad()  # reset gradients
            outputs = model(texts)  # forward pass
            loss = criterion(outputs, labels)  # compute the loss
            loss.backward()  # backpropagation
            optimizer.step()  # update parameters

            train_loss += loss.item()
            train_correct += (outputs.argmax(1) == labels).sum().item()

        # Test phase
        model.eval()
        test_loss, test_correct = 0.0, 0
        with torch.no_grad():  # disable gradient computation
            for texts, labels in test_loader:
                texts, labels = texts.to(device), labels.to(device)
                outputs = model(texts)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                test_correct += (outputs.argmax(1) == labels).sum().item()

        # Compute metrics
        train_acc = train_correct / len(train_loader.dataset)
        test_acc = test_correct / len(test_loader.dataset)
        avg_train_loss = train_loss / len(train_loader)
        avg_test_loss = test_loss / len(test_loader)

        # Print progress
        print(f"\nEpoch {epoch + 1}/{epochs}")
        print(f"Train Loss: {avg_train_loss:.4f} | Acc: {train_acc:.4f}")
        print(f"Test Loss: {avg_test_loss:.4f} | Acc: {test_acc:.4f}")

        # Learning rate scheduling
        scheduler.step(avg_test_loss)

        # Save the best model and check for early stopping
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            torch.save(model.state_dict(), model_path)
            print(f"Saved best model (Test Acc: {best_test_acc:.4f}) to {model_path}")
            no_improve_epochs = 0  # reset the counter
        else:
            no_improve_epochs += 1
            print(f"Test accuracy has not improved for {no_improve_epochs} epoch(s)")
            if no_improve_epochs >= early_stop_patience:
                print(f"Early stopping triggered ({early_stop_patience} epochs without improvement)")
                break

    return best_test_acc


def evaluate_model():
    _, test_loader = get_data_loaders()
    vocab_size = get_vocab_size() + 1
    model = LSTMClassifier(vocab_size).to(device)
    model.load_state_dict(torch.load(model_path))  # load the best model weights
    model.eval()

    criterion = nn.CrossEntropyLoss()
    test_loss, test_correct = 0.0, 0

    with torch.no_grad():
        for texts, labels in test_loader:
            texts, labels = texts.to(device), labels.to(device)
            outputs = model(texts)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            test_correct += (outputs.argmax(1) == labels).sum().item()

    test_acc = test_correct / len(test_loader.dataset)
    print(f"\n最终评估结果:")
    print(f"Test Loss: {test_loss / len(test_loader):.4f} | Test Acc: {test_acc:.4f}")
    return test_acc


def get_vocab_size():
    return len(load_word_index())

10. The Prediction Function

Finally, we add a function that classifies new text:

def predict_text(text):
    word_index = load_word_index()
    stopwords = load_stopwords()
    vocab_size = get_vocab_size() + 1

    # Text preprocessing: segment, filter stopwords, convert to indices
    words = jieba.cut(text)
    indexes = []
    for word in words:
        if word.strip() and word not in punk and word not in stopwords and word in word_index:
            indexes.append(word_index[word])

    # Unify the sequence length
    if len(indexes) > line_length_threshold:
        indexes = indexes[:line_length_threshold]
    else:
        indexes += [0] * (line_length_threshold - len(indexes))

    # Convert to a tensor and add a batch dimension
    input_tensor = torch.tensor(indexes, dtype=torch.long).unsqueeze(0).to(device)
    
    # Load the model and predict
    model = LSTMClassifier(vocab_size).to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    with torch.no_grad():
        output = model(input_tensor)
        prediction = torch.argmax(output, dim=1).item()

    return prediction
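Once training has produced weights/model.pth, using the function is a one-liner; the review text below is just an example:

# predict_text returns 1 for a positive review and 0 for a negative one
label = predict_text("房间干净,服务也很热情,下次还会再来!")
print("positive" if label == 1 else "negative")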

11. Complete Source Code

 

import os
import json
import torch
import numpy as np
import csv
import jieba
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.optim as opt

# Data file paths and basic configuration
csv_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './data', 'data.csv'))
word_index_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './data', 'word_index.json'))
index_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './data', 'index.jsonl'))
model_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './weights', 'model.pth'))
stopwords_path = os.path.relpath(os.path.join(os.path.dirname(__file__), './data', 'stopwords.txt'))  # stopword file path
punk = ',。!?,.:;!?()[]{}"\' '
line_length_threshold = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# New: check for and create necessary directories
def create_necessary_directories():
    """Create every directory the program needs."""
    directories = [
        os.path.dirname(csv_path),
        os.path.dirname(model_path),
        os.path.dirname(stopwords_path)
    ]
    
    for dir_path in directories:
        if not os.path.exists(dir_path):
            os.makedirs(dir_path, exist_ok=True)
            print(f"已创建目录: {dir_path}")


# New: check whether the CSV file exists
def check_csv_file():
    """Check whether the CSV file exists; if not, ask the user to add it."""
    if not os.path.exists(csv_path):
        print(f"\nError: no data file found at {csv_path}")
        print(f"Please place data.csv in the {os.path.dirname(csv_path)} directory")
        print("Expected CSV format: column 1 is the label (0 or 1), column 2 is the review text")
        return False
    return True


# New: load the stopword list
def load_stopwords():
    """Load the stopword list, creating a default one if the file does not exist."""
    if not os.path.exists(stopwords_path):
        # Default stopwords (extend as needed)
        default_stopwords = ["的", "了", "在", "是", "我", "有", "和", "就", "不", "人", "都", "一", "一个", "上", "也", "很", "到", "说", "要", "去", "你", "会", "着", "没有", "看", "好", "自己", "这"]
        with open(stopwords_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(default_stopwords))
        print(f"Created default stopword list: {stopwords_path}")
    with open(stopwords_path, 'r', encoding='utf-8') as f:
        return set(f.read().splitlines())


# Data preprocessing functions
def build_word_index():
    """Build the word-to-index map (with stopword filtering)"""
    index = 1
    word_index = {}
    stopwords = load_stopwords()

    with open(csv_path, 'r', encoding='utf-8-sig') as f, \
            open(word_index_path, 'w', encoding='utf-8-sig') as f1:

        reader = csv.reader(f)
        for row in reader:
            if len(row) < 2:
                continue
            text = row[1]
            words = jieba.cut(text)

            for word in words:
                # New: filter out stopwords
                if word.strip() and word not in punk and word not in stopwords and word not in word_index:
                    word_index[word] = index
                    index += 1

        json.dump(word_index, f1, ensure_ascii=False)
    print(f"已构建词索引表: {word_index_path}")
    return word_index


def load_word_index():
    if not os.path.exists(word_index_path):
        return build_word_index()
    with open(word_index_path, 'r', encoding='utf-8-sig') as f:
        return json.load(f)


def convert_text_to_index():
    """将文本转换为索引序列(含停用词过滤)"""
    word_index = load_word_index()
    stopwords = load_stopwords()

    with open(csv_path, 'r', encoding='utf-8-sig') as f, \
            open(index_path, 'w', encoding='utf-8-sig') as f1:

        reader = csv.reader(f)
        for row in reader:
            if len(row) < 2:
                continue
            label, text = row[0], row[1]
            line_indexes = []
            words = jieba.cut(text)

            for word in words:
                # New: filter out stopwords
                if word.strip() and word not in punk and word not in stopwords and word in word_index:
                    line_indexes.append(word_index[word])

            f1.write(json.dumps({"indexes": line_indexes, "label": label}, ensure_ascii=False) + '\n')
    print(f"已将文本转换为索引序列: {index_path}")


def load_processed_data():
    if not os.path.exists(index_path):
        convert_text_to_index()
    labels = []
    inputs = []

    with open(index_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            data = json.loads(line)
            indexes = data["indexes"]
            label = int(data["label"])

            if len(indexes) > line_length_threshold:
                processed = indexes[:line_length_threshold]
            else:
                processed = indexes + [0] * (line_length_threshold - len(indexes))

            inputs.append(processed)
            labels.append(label)

    print(f"已加载预处理数据,共 {len(labels)} 条记录")
    return inputs, labels


# Building the data loaders
def get_data_loaders(test_size=0.2, random_state=42, batch_size=32):
    inputs, labels = load_processed_data()
    train_data, test_data, train_labels, test_labels = train_test_split(
        inputs, labels, test_size=test_size, random_state=random_state, stratify=labels
    )

    train_dataset = TensorDataset(
        torch.tensor(train_data, dtype=torch.long),
        torch.tensor(train_labels, dtype=torch.long)
    )
    test_dataset = TensorDataset(
        torch.tensor(test_data, dtype=torch.long),
        torch.tensor(test_labels, dtype=torch.long)
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print(f"数据加载完成 - 训练集: {len(train_dataset)} 条, 测试集: {len(test_dataset)} 条")
    return train_loader, test_loader


# Bidirectional LSTM + pooling model definition
class LSTMClassifier(nn.Module):
    """Text classification model based on a bidirectional LSTM with pooling"""

    def __init__(self, vocab_size, embedding_dim=256, hidden_dim=256,  # hidden size halved to reduce overfitting
                 output_dim=2, num_layers=2, dropout=0.5):  # increased dropout rate
        super(LSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,  # bidirectional LSTM
            dropout=dropout if num_layers > 1 else 0
        )
        # The bidirectional LSTM outputs hidden_dim*2; average and max pooling are concatenated
        self.fc = nn.Linear(hidden_dim * 2 * 2, output_dim)  # 2 (bidirectional) * 2 (two pooling results)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x shape: (batch_size, seq_len)
        embedded = self.dropout(self.embedding(x))  # (batch_size, seq_len, embedding_dim)
        lstm_out, _ = self.lstm(embedded)  # (batch_size, seq_len, hidden_dim*2)
        # Pooling (instead of using only the last time step)
        avg_pool = torch.mean(lstm_out, dim=1)  # (batch_size, hidden_dim*2)
        max_pool, _ = torch.max(lstm_out, dim=1)  # (batch_size, hidden_dim*2)
        combined = torch.cat([avg_pool, max_pool], dim=1)  # (batch_size, hidden_dim*2*2)
        return self.fc(self.dropout(combined))


# Model training and evaluation (with LR scheduling and early stopping)
def train_model(epochs=50, lr=0.001, batch_size=32):
    train_loader, test_loader = get_data_loaders(batch_size=batch_size)
    vocab_size = get_vocab_size() + 1

    model = LSTMClassifier(vocab_size).to(device)
    optimizer = opt.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Learning rate scheduler: halve the LR if the test loss does not drop for 3 epochs
    scheduler = opt.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.5, verbose=True
    )

    best_test_acc = 0.0
    no_improve_epochs = 0  # counter of consecutive epochs without improvement
    early_stop_patience = 10  # early stopping threshold

    print(f"Starting training on device: {device}")
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, train_correct = 0.0, 0

        for texts, labels in train_loader:
            texts, labels = texts.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(texts)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_correct += (outputs.argmax(1) == labels).sum().item()

        # Test phase
        model.eval()
        test_loss, test_correct = 0.0, 0
        with torch.no_grad():
            for texts, labels in test_loader:
                texts, labels = texts.to(device), labels.to(device)
                outputs = model(texts)
                loss = criterion(outputs, labels)
                test_loss += loss.item()
                test_correct += (outputs.argmax(1) == labels).sum().item()

        # Compute metrics
        train_acc = train_correct / len(train_loader.dataset)
        test_acc = test_correct / len(test_loader.dataset)
        avg_train_loss = train_loss / len(train_loader)
        avg_test_loss = test_loss / len(test_loader)

        # Print progress
        print(f"\nEpoch {epoch + 1}/{epochs}")
        print(f"Train Loss: {avg_train_loss:.4f} | Acc: {train_acc:.4f}")
        print(f"Test Loss: {avg_test_loss:.4f} | Acc: {test_acc:.4f}")

        # Learning rate scheduling
        scheduler.step(avg_test_loss)

        # Save the best model and check for early stopping
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            torch.save(model.state_dict(), model_path)
            print(f"Saved best model (Test Acc: {best_test_acc:.4f}) to {model_path}")
            no_improve_epochs = 0  # reset the counter
        else:
            no_improve_epochs += 1
            print(f"Test accuracy has not improved for {no_improve_epochs} epoch(s)")
            if no_improve_epochs >= early_stop_patience:
                print(f"Early stopping triggered ({early_stop_patience} epochs without improvement)")
                break

    return best_test_acc


def evaluate_model():
    _, test_loader = get_data_loaders()
    vocab_size = get_vocab_size() + 1
    model = LSTMClassifier(vocab_size).to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    criterion = nn.CrossEntropyLoss()
    test_loss, test_correct = 0.0, 0

    with torch.no_grad():
        for texts, labels in test_loader:
            texts, labels = texts.to(device), labels.to(device)
            outputs = model(texts)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            test_correct += (outputs.argmax(1) == labels).sum().item()

    test_acc = test_correct / len(test_loader.dataset)
    print(f"\n最终评估结果:")
    print(f"Test Loss: {test_loss / len(test_loader):.4f} | Test Acc: {test_acc:.4f}")
    return test_acc


def get_vocab_size():
    return len(load_word_index())


# Prediction function
def predict_text(text):
    word_index = load_word_index()
    stopwords = load_stopwords()
    vocab_size = get_vocab_size() + 1

    words = jieba.cut(text)
    indexes = []
    for word in words:
        if word.strip() and word not in punk and word not in stopwords and word in word_index:
            indexes.append(word_index[word])

    if len(indexes) > line_length_threshold:
        indexes = indexes[:line_length_threshold]
    else:
        indexes += [0] * (line_length_threshold - len(indexes))

    input_tensor = torch.tensor(indexes, dtype=torch.long).unsqueeze(0).to(device)
    model = LSTMClassifier(vocab_size).to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    with torch.no_grad():
        output = model(input_tensor)
        prediction = torch.argmax(output, dim=1).item()

    return prediction


def test_model():
    """Test the prediction function on a few example reviews."""
    if not os.path.exists(model_path):
        print("Model file not found; please train the model first")
        return

    # Test texts: 2 positive and 2 negative reviews
    test_texts = [
        "这个旅店真不行,环境脏乱差,服务态度差",
        "这个旅店真不错,主要是便宜,价格便宜质量也不错",
        "店家服务很周到,态度也很好,下次还来",
        "今天的体验非常糟糕,工作人员态度恶劣,不会再来了"
    ]

    print("\nStarting model test...")
    for i, text in enumerate(test_texts):
        result = predict_text(text)
        sentiment = "positive" if result == 1 else "negative"
        print(f"Test text {i + 1}: {text}")
        print(f"Prediction: {sentiment} (label: {result})\n")


if __name__ == '__main__':
    # First create the necessary directories
    create_necessary_directories()
    
    # Check whether the CSV file exists
    if not check_csv_file():
        exit(1)  # exit if the CSV file is missing
    
    # Run the main pipeline
    build_word_index()
    convert_text_to_index()
    train_model(epochs=50)
    evaluate_model()

    test_model()

12. Getting a Dataset

The CSV file used in this post looks like this:

1,房间还算整齐宽敞,我住的是标准间大床房,只是灯泡有点问题,提交给店家,店家来换了一个。
1,物有所值,环境挺好的。这个价格这个质量已经很不错了。良心旅店!
0,屋里有垃圾,环境不够干净!
0,厕所的灯是坏的,晚上根本看不见!糟糕!

You can paste these sample lines into a large language model such as DeepSeek and ask it to generate more reviews for training.

Prompt:

1,房间还算整齐宽敞,我住的是标准间大床房,只是灯泡有点问题,提交给店家,店家来换了一个。
1,物有所值,环境挺好的。这个价格这个质量已经很不错了。良心旅店!
0,屋里有垃圾,环境不够干净!
0,厕所的灯是坏的,晚上根本看不见!糟糕!

请你帮我生成100条类似的评论,按照如上的格式,我要存入csv文件

Result: the LLM produces new labeled reviews in the same two-column format (screenshot omitted).
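If you save the LLM's output to a plain text file, a small helper like the one below can rewrite it into data/data.csv in the two-column format the training script expects. The file name generated_reviews.txt is only an assumption, and the script is meant to be run from the project root:

import csv

# Hypothetical input: one "label,text" line per review, as produced by the prompt above
with open('generated_reviews.txt', 'r', encoding='utf-8') as src, \
        open('./data/data.csv', 'w', encoding='utf-8-sig', newline='') as dst:
    writer = csv.writer(dst)
    for line in src:
        line = line.strip()
        if not line:
            continue
        # Split only on the first comma so commas inside the review text are kept
        label, text = line.split(',', 1)
        writer.writerow([label, text])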

 

