# 基于BERT的文本分类

发布于:2025-04-11 ⋅ 阅读:(42) ⋅ 点赞:(0)

基于BERT的文本分类项目的实现

一、项目背景

该文本分类项目主要是情感分析,二分类问题,以下是大致流程及部分代码示例:


二、数据集介绍

2.1 数据集基本信息

数据集 自定义
类型 二分类(正面/负面)
样本量 训练集 + 验证集 + 测试集
文本长度 平均x字(最大x字)
领域 商品评论、影视评论
# 加载数据集
dataset = pd.read_csv('data/train.txt', sep='\t')
print(dataset['train'][0])
# 输出:{'text': '这个手机性价比超高,拍照效果惊艳!', 'label': 1}

2.2 数据分析

2.2.1 句子长度分布
import matplotlib.pyplot as plt

def analyze_length(texts):
    lengths = [len(t) for t in texts]
    plt.figure(figsize=(12,5))
    plt.hist(lengths, bins=30, range=(0,256), color='blue', alpha=0.7)
    plt.title("文本长度分布", fontsize=14)
    plt.xlabel("字符数")
    plt.ylabel("样本量")
    plt.show()

analyze_length(dataset['train']['text'])
2.2.2 标签分布
import pandas as pd

pd.Series(dataset['train']['label']).value_counts().plot(
    kind='pie',
    autopct='%1.1f%%',
    title='类别分布(0-负面 1-正面)'
)
plt.show()
2.2.3 类别平衡处理
from torch.utils.data import WeightedRandomSampler

# 计算类别权重
labels = dataset['train']['label']
class_weights = 1 / torch.Tensor([len(labels)-sum(labels), sum(labels)])
sampler = WeightedRandomSampler(
    weights=[class_weights[label] for label in labels],
    num_samples=len(labels),
    replacement=True
)

三、数据处理

3.1 BERT分词器

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

def collate_fn(batch):
    texts = [item['text'] for item in batch]
    labels = [item['label'] for item in batch]
    
    # BERT编码
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors='pt'
    )
    return {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        'labels': torch.LongTensor(labels)
    }

3.2 数据加载器

from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset['train'],
    batch_size=32,
    collate_fn=collate_fn,
    sampler=sampler
)

val_loader = DataLoader(
    dataset['validation'],
    batch_size=32,
    collate_fn=collate_fn
)

四、模型构建

4.1 BERT分类模型

import torch.nn as nn
from transformers import BertModel

class BertClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(768, 2)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        return self.fc(pooled)

4.2 模型配置

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

五、模型训练与验证

5.1 训练流程

from tqdm import tqdm

def train_epoch(model, loader):
    model.train()
    total_loss = 0
    for batch in tqdm(loader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    return total_loss / len(loader)

5.2 验证流程

def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

六、实验结果

6.1 评估指标

Epoch 训练Loss 验证准确率 测试准确率
# 绘制混淆矩阵
from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(loader):
    y_true = []
    y_pred = []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            preds = torch.argmax(outputs, dim=1)
            
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    
    cm = confusion_matrix(y_true, y_pred)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.show()

plot_confusion_matrix(test_loader)

6.2 学习曲线

# 记录训练过程
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
for epoch in range(3):
    train_loss = train_epoch(model, train_loader)
    val_acc = evaluate(model, val_loader)
    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Accuracy/Validation', val_acc, epoch)

七、流程架构图

原始文本
分词编码
BERT特征提取
全连接分类
损失计算
反向传播
模型评估


网站公告

今日签到

点亮在社区的每一天
去签到