BERT的中文问答系统17

发布于:2024-10-18 ⋅ 阅读:(15) ⋅ 点赞:(0)
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import pandas as pd
from tqdm import tqdm

# Absolute directory that contains this script; data/, models/, logs/ and icons/
# are all resolved relative to it.
PROJECT_ROOT = os.path.split(os.path.abspath(__file__))[0]

# Root directory for log files; created at import time when missing.
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    """Route INFO-and-above log records to the console and to a fresh per-run file.

    The file lives at ``LOGS_DIR/<YYYY-MM-DD>/<HH-MM-SS>/羲和.txt`` and is written as
    UTF-8, so the Chinese log messages are stored correctly whatever the system
    locale encoding is.

    Returns:
        The path of the log file that was configured.
    """
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d/%H-%M-%S/羲和.txt'))
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Without an explicit encoding FileHandler uses the locale's default,
            # which may not be able to encode the Chinese messages logged here.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    return log_file


setup_logging()

# Dataset class
class XihuaDataset(Dataset):
    """Question/answer records loaded from a ``.jsonl`` or ``.json`` file.

    A record is kept only if it is a JSON object containing ``question``,
    ``xihe_answers`` and ``ling_answers``. ``__getitem__`` tokenizes the question and
    the first answer of each kind, padded/truncated to ``max_length`` tokens.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        """
        Args:
            file_path: ``.jsonl`` file (one object per line) or ``.json`` file (a list
                of objects); any other extension yields an empty dataset.
            tokenizer: Hugging Face style tokenizer callable (e.g. ``BertTokenizer``).
            max_length: pad/truncate length for every tokenized text.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Read and validate records from ``file_path``.

        Missing or malformed files are logged and not raised. For ``.jsonl`` the
        records read before the first invalid line are kept; for ``.json`` the
        result is empty.
        """
        data = []
        if file_path.endswith('.jsonl'):
            try:
                with jsonlines.open(file_path) as reader:
                    for item in reader:
                        if self.validate_item(item):
                            data.append(item)
            except (FileNotFoundError, jsonlines.jsonlines.InvalidLineError) as e:
                logging.warning(f"跳过无效文件 {file_path}: {e}")
        elif file_path.endswith('.json'):
            try:
                # Explicit UTF-8: the default text encoding is locale dependent
                # (e.g. GBK on Chinese Windows) and would mangle or reject the data.
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = [item for item in json.load(f) if self.validate_item(item)]
            except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError, TypeError) as e:
                # TypeError: the top-level JSON value is not iterable (e.g. a number).
                logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def validate_item(self, item):
        """Return True if ``item`` is a dict with all required keys, else log and return False."""
        required_keys = ['question', 'xihe_answers', 'ling_answers']
        # Non-object records would make `key in item` raise (numbers) or perform a
        # substring test (strings), so only dicts are accepted.
        if isinstance(item, dict) and all(key in item for key in required_keys):
            return True
        logging.warning(f"跳过无效项: 缺少必要键 {required_keys}")
        return False

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized sample at ``idx``.

        An out-of-range ``idx`` raises IndexError, as for any sequence. If
        tokenization fails, the following items are tried in turn (wrapping around).
        RuntimeError is raised once every item has failed, instead of recursing
        without bound.
        """
        attempts = len(self.data)
        while True:
            item = self.data[idx]  # IndexError for out-of-range idx (also when empty)
            try:
                return self._encode(item)
            except Exception as e:
                logging.warning(f"跳过无效项 {idx}: {e}")
            attempts -= 1
            if attempts <= 0:
                raise RuntimeError("数据集中没有可以成功分词的数据项")
            idx = (idx + 1) % len(self.data)

    def _encode(self, item):
        """Tokenize one record; each answer kind falls back to the other when empty."""
        question = item['question']

        xihe_answer = item.get('xihe_answers', [])
        ling_answer = item.get('ling_answers', [])

        if not xihe_answer and ling_answer:
            xihe_answer = ling_answer
        elif not ling_answer and xihe_answer:
            ling_answer = xihe_answer

        xihe_answer = xihe_answer[0] if xihe_answer else ""
        ling_answer = ling_answer[0] if ling_answer else ""

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        xihe_inputs = self.tokenizer(xihe_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        ling_inputs = self.tokenizer(ling_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)

        # squeeze(0) removes only the batch dim added by return_tensors='pt'; a bare
        # squeeze() would also drop the sequence dim when max_length == 1.
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'xihe_input_ids': xihe_inputs['input_ids'].squeeze(0),
            'xihe_attention_mask': xihe_inputs['attention_mask'].squeeze(0),
            'ling_input_ids': ling_inputs['input_ids'].squeeze(0),
            'ling_attention_mask': ling_inputs['attention_mask'].squeeze(0),
            'xihe_answer': xihe_answer,
            'ling_answer': ling_answer
        }

# Data loader factory
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    """Build a shuffling DataLoader over an ``XihuaDataset`` read from ``file_path``.

    Args:
        file_path: ``.jsonl`` / ``.json`` data file passed to ``XihuaDataset``.
        tokenizer: tokenizer used by the dataset.
        batch_size: samples per batch.
        max_length: pad/truncate length for each tokenized text.

    Returns:
        A ``DataLoader`` with ``shuffle=True``.

    Raises:
        Whatever building the dataset or loader raised, after logging it.
    """
    try:
        loader = DataLoader(
            XihuaDataset(file_path, tokenizer, max_length),
            batch_size=batch_size,
            shuffle=True,
        )
    except Exception as err:
        logging.error(f"获取数据加载器失败: {err}")
        raise
    return loader

# Model definition
class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a ``Linear(hidden_size, 1)`` head on ``pooler_output``.

    The submodule attribute names ``bert`` and ``classifier`` determine the state_dict
    keys of saved checkpoints (models/xihua_model.pth) and must not be renamed.
    """

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        """Load the pretrained encoder from ``pretrained_model_name`` and add the head."""
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return raw logits (no sigmoid); training pairs them with BCEWithLogitsLoss."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)

# Training loop (one epoch)
def train(model, data_loader, optimizer, criterion, device):
    """Run one training epoch over ``data_loader``.

    For every batch the "xihe" answers are labelled 1 and the "ling" answers 0. The
    two ``criterion`` losses are summed and one optimizer step is taken. Batches
    that raise are logged and skipped.

    Args:
        model: module mapping (input_ids, attention_mask) to logits.
        data_loader: iterable of dict batches as produced by ``XihuaDataset``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking (logits, labels), e.g. ``BCEWithLogitsLoss``.
        device: device the batch tensors are moved to.

    Returns:
        ``(mean_loss, losses)``. ``losses`` holds the per-batch losses of the batches
        actually trained. ``mean_loss`` is their average, or 0.0 when no batch was
        trained (previously an empty loader raised ZeroDivisionError, and skipped
        batches diluted the average).
    """
    model.train()
    losses = []
    for batch in tqdm(data_loader, desc="Training"):
        try:
            # NOTE(review): the question tensors (batch['input_ids'] /
            # batch['attention_mask']) are not used for training, yet the GUI
            # classifies the *question* at inference time. Confirm this is intended.
            xihe_input_ids = batch['xihe_input_ids'].to(device)
            xihe_attention_mask = batch['xihe_attention_mask'].to(device)
            ling_input_ids = batch['ling_input_ids'].to(device)
            ling_attention_mask = batch['ling_attention_mask'].to(device)

            optimizer.zero_grad()
            xihe_logits = model(xihe_input_ids, xihe_attention_mask)
            ling_logits = model(ling_input_ids, ling_attention_mask)

            # Labels with the same shape/dtype/device as the logits.
            xihe_labels = torch.ones_like(xihe_logits)
            ling_labels = torch.zeros_like(ling_logits)

            loss = criterion(xihe_logits, xihe_labels) + criterion(ling_logits, ling_labels)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")

    mean_loss = sum(losses) / len(losses) if losses else 0.0
    return mean_loss, losses

# Stand-alone training entry point
def main_train(retrain=False):
    """Fine-tune XihuaModel on data/train_data.jsonl and save it to models/xihua_model.pth.

    Args:
        retrain: if True, first load the weights saved in models/xihua_model.pth and
            continue training from them.

    Raises:
        Re-raises any exception from loading, training or saving, after logging it
        together with its traceback.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device: {device}')
    model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')

    try:
        tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)

        if retrain:
            model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

        optimizer = optim.AdamW(model.parameters(), lr=1e-5)
        criterion = torch.nn.BCEWithLogitsLoss()

        train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)

        num_epochs = 3
        for epoch in range(num_epochs):
            train_loss, losses = train(model, train_data_loader, optimizer, criterion, device)
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')
            # NOTE(review): plot_losses() draws into the module-level Tk `root`, which
            # is only created when the script runs as __main__. Confirm before
            # calling main_train() without the GUI.
            plot_losses(losses)

        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(model.state_dict(), model_path)
        logging.info("模型训练完成并保存")
    except Exception as e:
        # logging.exception logs at ERROR level and appends the traceback.
        logging.exception(f"主训练函数失败: {e}")
        raise

# Loss curve plotting
def plot_losses(losses, master=None):
    """Embed a per-batch training-loss curve in a Tk window.

    Args:
        losses: per-batch loss values (as returned by ``train``).
        master: Tk widget to draw into. It defaults to the module-level ``root``
            window created under ``__main__``. When neither is available the plot is
            skipped with a warning instead of raising NameError.

    Only the most recent chart is kept. The previously embedded chart is destroyed,
    so repeated calls (one per epoch) no longer keep enlarging the window.
    """
    # Use a plain Figure rather than pyplot: pyplot registers every figure globally
    # (they are never closed here) and manages its own backend windows, which is
    # not wanted when embedding into an existing Tk application.
    from matplotlib.figure import Figure

    if master is None:
        master = globals().get('root')
        if master is None:
            logging.warning("未找到 Tk 主窗口,跳过损失曲线绘制")
            return

    previous = getattr(plot_losses, '_widget', None)
    if previous is not None:
        try:
            previous.destroy()
        except tk.TclError:
            pass  # the old widget's window is already gone

    fig = Figure()
    ax = fig.add_subplot(111)
    ax.plot(losses)
    ax.set_xlabel('Batch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss')
    canvas = FigureCanvasTkAgg(fig, master=master)
    canvas.draw()
    widget = canvas.get_tk_widget()
    widget.pack(side=tk.TOP, fill=tk.BOTH, expand=1)
    plot_losses._widget = widget

# GUI
class XihuaChatbotGUI:
    """Tkinter chat window: question entry, answer pane, history and training controls.

    ``get_answer`` classifies the question with ``XihuaModel`` (logit > 0 selects
    "羲和回答", otherwise "零回答"). It then returns the stored answer of the most
    similar question found in data/train_data.jsonl.
    """

    def __init__(self, root):
        """Build the window, load (or train) the model and load the answer corpus.

        Args:
            root: Tk root window hosting all widgets.
        """
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)

        self.history = []
        self.train_mode_var = tk.BooleanVar()
        self.status_label = tk.Label(self.root, text="")

        # Build the widgets *before* a first-run training can start. train_model()
        # updates self.progress, which previously did not exist yet at this point, so
        # the very first training run failed after epoch 1 and never saved a model.
        self.create_widgets()

        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if not os.path.exists(model_path):
            messagebox.showinfo("模型未找到", "未找到现有模型,将开始训练新的模型")
            self.train_model()
        else:
            self.load_model()
        # Disable dropout for inference on both start-up paths.
        self.model.eval()

        # The training data doubles as the answer corpus searched by get_specific_answer().
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

    def create_widgets(self):
        """Create and pack all widgets. ``self.status_label`` must already exist."""
        # Window icon. iconbitmap() can raise TclError (e.g. for .ico files on some
        # platforms); a missing icon should not abort start-up.
        icon_path = os.path.join(PROJECT_ROOT, 'icons/xihe.ico')
        if os.path.exists(icon_path):
            try:
                self.root.iconbitmap(icon_path)
            except tk.TclError as e:
                logging.warning(f"设置窗口图标失败: {e}")

        # Question entry row
        self.question_frame = tk.Frame(self.root)
        self.question_frame.pack(pady=10)

        self.question_label = tk.Label(self.question_frame, text="问题:")
        self.question_label.pack(side=tk.LEFT)

        self.question_entry = tk.Entry(self.question_frame, width=50)
        self.question_entry.pack(side=tk.LEFT)

        self.answer_button = tk.Button(self.question_frame, text="获取回答", command=self.get_answer)
        self.answer_button.pack(side=tk.LEFT)

        # Answer display
        self.answer_frame = tk.Frame(self.root)
        self.answer_frame.pack(pady=10)

        self.answer_label = tk.Label(self.answer_frame, text="回答:")
        self.answer_label.pack()

        self.answer_text = tk.Text(self.answer_frame, height=10, width=50)
        self.answer_text.pack()

        # History
        self.history_frame = tk.Frame(self.root)
        self.history_frame.pack(pady=10)

        self.history_label = tk.Label(self.history_frame, text="历史记录:")
        self.history_label.pack()

        self.history_text = tk.Text(self.history_frame, height=10, width=50)
        self.history_text.pack()

        # Training controls
        self.train_mode_frame = tk.Frame(self.root)
        self.train_mode_frame.pack(pady=10)

        self.train_mode_checkbutton = tk.Checkbutton(self.train_mode_frame, text="继续训练现有模型", variable=self.train_mode_var)
        self.train_mode_checkbutton.pack(side=tk.LEFT)

        self.train_button = tk.Button(self.train_mode_frame, text="训练模型", command=self.train_model)
        self.train_button.pack(side=tk.LEFT)

        self.retrain_button = tk.Button(self.train_mode_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True))
        self.retrain_button.pack(side=tk.LEFT)

        # Training progress bar
        self.progress = ttk.Progressbar(self.root, orient='horizontal', mode='determinate')
        self.progress.pack(fill=tk.X, padx=10, pady=10)

        # Status line
        self.status_label.pack()

    def get_answer(self):
        """Classify the entered question, show the matching answer and record it in the history."""
        question = self.question_entry.get().strip()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        try:
            # train() leaves the model in training mode; make inference deterministic.
            self.model.eval()
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
            with torch.no_grad():
                input_ids = inputs['input_ids'].to(self.device)
                attention_mask = inputs['attention_mask'].to(self.device)
                logits = self.model(input_ids, attention_mask)

            # NOTE(review): train() fits the classifier on answer texts (xihe=1,
            # ling=0), but here it is applied to the question text. Confirm intended.
            answer_type = "羲和回答" if logits.item() > 0 else "零回答"

            specific_answer = self.get_specific_answer(question, answer_type)

            self.answer_text.delete(1.0, tk.END)
            self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")

            self.history.append((question, specific_answer))
            self.update_history()
        except Exception as e:
            logging.error(f"获取回答失败: {e}")
            messagebox.showerror("获取回答失败", f"获取回答失败: {e}")

    def update_history(self):
        """Redraw the history pane from ``self.history``."""
        self.history_text.delete(1.0, tk.END)
        for q, a in self.history:
            self.history_text.insert(tk.END, f"Q: {q}\nA: {a}\n\n")

    def get_specific_answer(self, question, answer_type):
        """Return an answer from the corpus entry whose question is most similar.

        Similarity is ``difflib.SequenceMatcher(None, question, item_question).ratio()``.
        The first entry with the highest strictly-positive ratio wins. The fallback
        text is returned when no entry matches or the best entry has no answers.
        """
        fallback = "这个我也不清楚,你问问零吧"
        best_match = None
        best_ratio = 0.0
        # Reusing one matcher with a fixed seq1 gives the same ratio() as building
        # SequenceMatcher(None, question, candidate) per item.
        matcher = SequenceMatcher(None)
        matcher.set_seq1(question)
        for item in self.data:
            matcher.set_seq2(item['question'])
            # real_quick_ratio() and quick_ratio() are cheap upper bounds of ratio().
            # A candidate whose bound cannot exceed the current best is skipped, so
            # the selected match is exactly the same as with a full scan.
            if matcher.real_quick_ratio() <= best_ratio or matcher.quick_ratio() <= best_ratio:
                continue
            ratio = matcher.ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match is None:
            return fallback
        answer = self._pick_answer(best_match, answer_type)
        # Previously an entry with both answer lists empty raised IndexError here.
        return fallback if answer is None else answer

    @staticmethod
    def _pick_answer(item, answer_type):
        """Return the first answer of the requested kind (falling back to the other kind), or None if both are empty."""
        xihe_answers = item.get('xihe_answers') or []
        ling_answers = item.get('ling_answers') or []
        if answer_type == "羲和回答":
            candidates = xihe_answers or ling_answers
        else:
            candidates = ling_answers or xihe_answers
        return candidates[0] if candidates else None

    def load_data(self, file_path):
        """Load valid records from a ``.jsonl`` / ``.json`` file.

        A missing file is logged and yields an empty corpus, so the GUI can still
        start. Any other error is logged and re-raised.
        """
        data = []
        try:
            if file_path.endswith('.jsonl'):
                with jsonlines.open(file_path) as reader:
                    data = [item for item in reader if self.validate_item(item)]
            elif file_path.endswith('.json'):
                # Explicit UTF-8: the default encoding is locale dependent.
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = [item for item in json.load(f) if self.validate_item(item)]
        except FileNotFoundError:
            logging.warning(f"训练数据文件不存在: {file_path}")
            return []
        except Exception as e:
            logging.error(f"加载数据失败: {e}")
            raise
        return data

    def validate_item(self, item):
        """Return True if ``item`` is a dict with all required keys, else log and return False."""
        required_keys = ['question', 'xihe_answers', 'ling_answers']
        # Only dicts: `key in item` raises for numbers and is a substring test for strings.
        if isinstance(item, dict) and all(key in item for key in required_keys):
            return True
        logging.warning(f"跳过无效项: 缺少必要键 {required_keys}")
        return False

    def load_model(self):
        """Load saved weights into ``self.model`` if the file exists; errors are logged and shown, not raised."""
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            try:
                self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
                logging.info("加载现有模型")
            except Exception as e:
                logging.error(f"加载模型失败: {e}")
                messagebox.showerror("加载失败", f"加载模型失败: {e}")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        """Train on a user-selected data file for 3 epochs and save the weights.

        Args:
            retrain: if True, or if the "继续训练现有模型" box is ticked, the weights
                saved in models/xihua_model.pth are loaded first. Otherwise training
                continues from the in-memory model.
        """
        if not hasattr(self, 'train_mode_var'):
            self.train_mode_var = tk.BooleanVar()
        if not hasattr(self, 'status_label'):
            self.status_label = tk.Label(self.root, text="")
            self.status_label.pack()
        # The progress bar is created by create_widgets(); tolerate its absence.
        progress = getattr(self, 'progress', None)
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')

        try:
            file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
            if not file_path:
                messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
                return

            dataset = XihuaDataset(file_path, self.tokenizer)
            if len(dataset) == 0:
                messagebox.showwarning("数据为空", "所选文件中没有有效的训练数据")
                return
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

            # Continue from the saved weights (note: "重新训练模型" also does this,
            # it does not start from the plain pretrained checkpoint).
            if retrain or self.train_mode_var.get():
                self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
                self.model.to(self.device)
                self.model.train()

            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 3
            if progress is not None:
                progress['value'] = 0
            for epoch in range(num_epochs):
                self.status_label.config(text=f"正在训练 Epoch {epoch+1}/{num_epochs}")
                self.root.update_idletasks()
                train_loss, losses = train(self.model, data_loader, optimizer, criterion, self.device)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
                plot_losses(losses)
                if progress is not None:
                    progress['value'] = (epoch + 1) / num_epochs * 100
                self.root.update_idletasks()

            # torch.save does not create missing directories.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            torch.save(self.model.state_dict(), model_path)
            logging.info("模型训练完成并保存")
            messagebox.showinfo("训练完成", "模型训练完成并保存")
            self.status_label.config(text="训练完成")
        except Exception as e:
            logging.error(f"模型训练失败: {e}")
            messagebox.showerror("训练失败", f"模型训练失败: {e}")
            self.status_label.config(text="训练失败")
        finally:
            # train() switches the model to training mode (dropout on). Restore eval
            # mode so later answers are deterministic.
            self.model.eval()

# Entry point
if __name__ == "__main__":
    # Start the GUI. `root` must stay a module-level name because plot_losses()
    # embeds the loss charts into it.
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    app.root.mainloop()

后续改进方向:进一步完善 GUI 的用户交互功能;采用更高效的算法和数据结构以提升代码性能;补充安装指南、使用指南和 API 文档;增加更多数据清洗与标准化步骤以提升模型效果;使用更高效的数据加载与处理方法;增加异常处理以应对可能出现的运行时错误,并确保日志包含足够的信息以便调试和监控;尝试更高效的优化器和损失函数;使用加密技术保护敏感数据。
project_root/
├── data/
│   └── train_data.jsonl
├── models/
│   └── xihua_model.pth
├── logs/
│   └── <YYYY-MM-DD>/
│       └── <HH-MM-SS>/
│           └── 羲和.txt
├── icons/
│   ├── xihe.ico
│   └── ling.png
└── main.py

README.md

运行以下命令,将 README 内容写入 project_root/README.md:

echo "# 羲和聊天机器人

## 项目简介

羲和聊天机器人是一个基于BERT模型的中文聊天机器人,能够根据用户输入的问题提供相应的回答。项目包括数据加载、模型训练、模型推理和图形用户界面(GUI)等功能。

## 项目结构

\`\`\`
project_root/
├── data/
│   └── train_data.jsonl
├── models/
│   └── xihua_model.pth
├── logs/
│   └── <YYYY-MM-DD>/
│       └── <HH-MM-SS>/
│           └── 羲和.txt
├── icons/
│   ├── xihe.ico
│   └── ling.png
└── main.py
\`\`\`

## 安装指南

1. **安装Python**:确保已安装Python 3.7或更高版本。
2. **安装依赖库**:
   \`\`\`sh
   pip install torch transformers jsonlines pandas tqdm matplotlib
   \`\`\`
3. **下载预训练模型**:从 Hugging Face 下载 \`bert-base-chinese\` 模型,并放置在 \`F:/models/bert-base-chinese\` 目录下(该路径在代码中硬编码,如需更换位置,请同步修改代码中的路径)。

## 使用指南

1. **启动程序**:运行 \`main.py\` 文件启动程序。
   \`\`\`sh
   python main.py
   \`\`\`
2. **输入问题**:在问题输入框中输入问题,点击“获取回答”按钮获取回答。
3. **查看历史记录**:在历史记录区域查看之前的问答记录。
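4. **训练模型**:
   - 首次启动时如果未找到 \`models/xihua_model.pth\`,程序会弹出提示并直接进入训练流程。
   - 如需在已保存的模型基础上继续训练,请先勾选“继续训练现有模型”,再点击“训练模型”。
   - 点击“训练模型”按钮并选择训练数据文件(\`.jsonl\` 或 \`.json\`),程序会训练 3 个 epoch 并保存模型。
   - 点击“重新训练模型”按钮会先加载已保存的模型权重,再使用所选数据文件训练 3 个 epoch。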

## 项目目录结构

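- **\`data/\`**:存放训练数据文件。每条数据(\`.jsonl\` 文件中每行一条)需包含 \`question\`、\`xihe_answers\`、\`ling_answers\` 三个字段,例如 \`{\"question\": \"你是谁?\", \"xihe_answers\": [\"我是羲和\"], \"ling_answers\": [\"我是零\"]}\`。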
- **\`models/\`**:存放模型文件。
- **\`logs/\`**:存放日志文件。
- **\`icons/\`**:存放图标文件。
- **\`main.py\`**:包含所有代码的主文件。

## 代码说明

### 数据集类 (\`XihuaDataset\`)

- **\`load_data\`**:加载数据文件,支持 JSONL 和 JSON 格式。
- **\`validate_item\`**:验证数据项是否有效。
- **\`__getitem__\`**:返回问题、羲和回答和零回答的分词结果(\`input_ids\`、\`attention_mask\`)以及原始回答文本;训练标签在 \`train\` 函数中生成(羲和回答为 1,零回答为 0)。

### 模型类 (\`XihuaModel\`)

- **\`forward\`**:前向传播,返回模型的输出。

### 训练函数 (\`train\`)

- **\`train\`**:训练模型,返回平均损失和每批次的损失。

### 主训练函数 (\`main_train\`)

- **\`main_train\`**:主训练函数,调用训练函数进行模型训练。

### 绘制损失图 (\`plot_losses\`)

- **\`plot_losses\`**:绘制训练过程中的损失曲线。

### GUI界面 (\`XihuaChatbotGUI\`)

- **\`create_widgets\`**:创建GUI界面的各个组件。
- **\`get_answer\`**:获取模型的回答。
- **\`update_history\`**:更新历史记录。
- **\`get_specific_answer\`**:根据问题获取具体的回答。
- **\`load_data\`**:加载训练数据。
- **\`validate_item\`**:验证数据项是否有效。
- **\`load_model\`**:加载模型。
- **\`train_model\`**:训练模型。

## 联系方式

如有任何问题或建议,请联系 [554687453@qq.com]。

## 许可证

本项目采用MIT许可证,详情参见 [LICENSE](LICENSE) 文件。" > project_root/README.md