The Complete Training Workflow for a BERT Classification Model

Published: 2025-07-23

The complete training pipeline for a text classification model built with BERT

BERT (Bidirectional Encoder Representations from Transformers) is a powerful pre-trained language model that performs strongly across a wide range of NLP tasks. Below, I walk through the complete process of training a text classification model with BERT.

1. Setup

1.1 Environment

pip install transformers torch pandas scikit-learn

1.2 Choosing a BERT Variant

  • BERT-base (110M parameters)
  • BERT-large (340M parameters)
  • Chinese BERT (e.g. bert-base-chinese)
  • Domain-specific BERT (e.g. BioBERT, SciBERT) — all are loaded the same way, as sketched below
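
Switching variants only changes the checkpoint name passed to from_pretrained. A short illustration with the English base model (the Chinese checkpoint actually used in this post is loaded in section 3.2):

from transformers import BertTokenizer, BertForSequenceClassification

# Example: load an English BERT-base checkpoint for a 2-class task
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)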

2. Data Preparation

2.1 Data Format

text,label
"这个电影很好看",1
"产品体验很差",0
...

2.2 Data Preprocessing

import pandas as pd
from sklearn.model_selection import train_test_split

# Load the data
df = pd.read_csv('data.csv')

# Split into training and validation sets (stratify keeps the label distribution consistent)
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])

# Check the class distribution
print(train_df['label'].value_counts())

3. Loading BERT with the Transformers Library

3.1 Import the Required Components

from transformers import BertTokenizer, BertForSequenceClassification
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW  # transformers' own AdamW is deprecated; use the PyTorch implementation
import numpy as np  # used by the training/evaluation functions below
import torch
from torch.utils.data import Dataset, DataLoader

3.2 Initialize the Tokenizer and Model

# Choose the pre-trained checkpoint
MODEL_NAME = 'bert-base-chinese'  # Chinese model

# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

# Load the model with a sequence classification head
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(train_df['label'].unique()),  # number of classes
    output_attentions=False,
    output_hidden_states=False
)

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
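
Before wrapping the data in a Dataset, it helps to check what the tokenizer returns for a single sentence; a quick sanity check using one of the sample texts from section 2.1:

sample = tokenizer.encode_plus(
    '这个电影很好看',
    add_special_tokens=True,   # adds [CLS] and [SEP]
    max_length=16,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)
print(sample['input_ids'])       # token ids, padded to max_length
print(sample['attention_mask'])  # 1 for real tokens, 0 for padding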

4. Building the Dataset and DataLoaders

4.1 Custom Dataset Class

class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        
        return {
            'text': text,  # raw string, used later by the prediction function in section 7.2
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

4.2 Create the DataLoaders

MAX_LEN = 128  # sequence length used for this task (BERT itself supports up to 512 tokens)
BATCH_SIZE = 32

def create_data_loader(df, tokenizer, max_len, batch_size, shuffle=False):
    ds = TextDataset(
        texts=df['text'].to_numpy(),
        labels=df['label'].to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len
    )
    
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4
    )

train_data_loader = create_data_loader(train_df, tokenizer, MAX_LEN, BATCH_SIZE, shuffle=True)  # shuffle the training data each epoch
val_data_loader = create_data_loader(val_df, tokenizer, MAX_LEN, BATCH_SIZE)
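
A quick check that the loaders produce tensors of the expected shape:

batch = next(iter(train_data_loader))
print(batch['input_ids'].shape)       # e.g. torch.Size([32, 128])
print(batch['attention_mask'].shape)  # e.g. torch.Size([32, 128])
print(batch['labels'].shape)          # e.g. torch.Size([32])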

5. Training Setup

5.1 Optimizer and Learning Rate Scheduler

EPOCHS = 3
LEARNING_RATE = 2e-5

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)  # torch.optim.AdamW (see imports in section 3.1)
total_steps = len(train_data_loader) * EPOCHS

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

loss_fn = torch.nn.CrossEntropyLoss().to(device)  # kept for reference; the model already returns a cross-entropy loss when labels are passed

5.2 Training Function

def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
    model = model.train()
    losses = []
    correct_predictions = 0
    
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        labels = d["labels"].to(device)
        
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        loss = outputs.loss
        logits = outputs.logits
        
        _, preds = torch.max(logits, dim=1)
        correct_predictions += torch.sum(preds == labels)
        losses.append(loss.item())
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    
    return correct_predictions.double() / n_examples, np.mean(losses)

5.3 Evaluation Function

def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval()
    losses = []
    correct_predictions = 0
    
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            labels = d["labels"].to(device)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            loss = outputs.loss
            logits = outputs.logits
            
            _, preds = torch.max(logits, dim=1)
            correct_predictions += torch.sum(preds == labels)
            losses.append(loss.item())
    
    return correct_predictions.double() / n_examples, np.mean(losses)

6. Training Loop

from collections import defaultdict
import numpy as np

history = defaultdict(list)
best_accuracy = 0

for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)
    
    train_acc, train_loss = train_epoch(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(train_df)
    )
    
    print(f'Train loss {train_loss} accuracy {train_acc}')
    
    val_acc, val_loss = eval_model(
        model,
        val_data_loader,
        loss_fn,
        device,
        len(val_df)
    )
    
    print(f'Val loss {val_loss} accuracy {val_acc}')
    print()
    
    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)
    
    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_accuracy = val_acc
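
After training, the history dict can be used to plot the learning curves. A minimal sketch (matplotlib is an extra dependency not listed in section 1.1; accuracies are stored as 0-dim tensors, hence the float() conversion):

import matplotlib.pyplot as plt

plt.plot([float(a) for a in history['train_acc']], label='train accuracy')
plt.plot([float(a) for a in history['val_acc']], label='val accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()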

7. Model Evaluation and Prediction

7.1 Load the Best Model

model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(train_df['label'].unique())  # must match the label count used during training
)
model.load_state_dict(torch.load('best_model_state.bin', map_location=device))
model = model.to(device)

7.2 Prediction Function

def get_predictions(model, data_loader):
    model = model.eval()
    review_texts = []
    predictions = []
    prediction_probs = []
    real_values = []
    
    with torch.no_grad():
        for d in data_loader:
            texts = d["text"]
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            labels = d["labels"].to(device)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            
            _, preds = torch.max(outputs.logits, dim=1)
            
            probs = torch.softmax(outputs.logits, dim=1)
            
            review_texts.extend(texts)
            predictions.extend(preds)
            prediction_probs.extend(probs)
            real_values.extend(labels)
    
    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    real_values = torch.stack(real_values).cpu()
    
    return review_texts, predictions, prediction_probs, real_values

7.3 Generate a Classification Report

from sklearn.metrics import classification_report, confusion_matrix

y_review_texts, y_pred, y_pred_probs, y_test = get_predictions(model, val_data_loader)

print(classification_report(y_test, y_pred))
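
confusion_matrix is imported above but not yet used; it gives a per-class breakdown of the same predictions:

cm = confusion_matrix(y_test, y_pred)
print(cm)  # rows = true labels, columns = predicted labels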

8. Saving and Deploying the Model

8.1 Save the Full Model

model.save_pretrained("./my_bert_classifier")
tokenizer.save_pretrained("./my_bert_classifier")

8.2 Example Prediction API

from flask import Flask, request, jsonify
import torch

app = Flask(__name__)

# Load the model and tokenizer
model = BertForSequenceClassification.from_pretrained('./my_bert_classifier')
tokenizer = BertTokenizer.from_pretrained('./my_bert_classifier')
model.eval()

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()
    text = data['text']
    
    encoded_text = tokenizer.encode_plus(
        text,
        max_length=128,
        add_special_tokens=True,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,  # truncate long inputs so they fit within max_length
        return_attention_mask=True,
        return_tensors='pt',
    )
    
    input_ids = encoded_text['input_ids']
    attention_mask = encoded_text['attention_mask']
    
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)
    
    _, prediction = torch.max(output.logits, dim=1)
    prob = torch.softmax(output.logits, dim=1)
    
    return jsonify({
        'prediction': prediction.item(),
        'probability': prob[0][prediction.item()].item()
    })

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
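
To verify the endpoint, send a request from another process. A minimal client sketch assuming the service above is running locally on port 5000 (requests is an extra dependency):

import requests

resp = requests.post(
    'http://localhost:5000/predict',
    json={'text': '这个电影很好看'}
)
print(resp.json())  # e.g. {'prediction': 1, 'probability': 0.98}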

Key Points to Keep in Mind

  1. Learning rate: BERT fine-tuning typically uses a very small learning rate (2e-5 to 5e-5)
  2. Batch size: choose according to GPU memory, typically 16-64
  3. Epochs: 3-5 epochs are usually enough when fine-tuning BERT
  4. Sequence length: tune MAX_LEN to the task; overly long sequences waste compute
  5. Class imbalance: pass per-class weights to the loss function (CrossEntropyLoss's weight argument), as sketched after this list
  6. GPU: training with CUDA acceleration is strongly recommended
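
For point 5, a minimal sketch of weighted cross-entropy (inverse-frequency weights are just one possible scheme). Note that for the weights to take effect, train_epoch must compute loss = loss_fn(logits, labels) from the logits instead of using outputs.loss:

import torch

# Inverse-frequency class weights derived from the training labels
label_counts = train_df['label'].value_counts().sort_index()
class_weights = torch.tensor(
    (len(train_df) / (len(label_counts) * label_counts)).values,
    dtype=torch.float
).to(device)

# Weighted loss; use loss_fn(logits, labels) in train_epoch instead of outputs.loss
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)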

With the workflow above, you can implement a BERT-based text classification model end to end, from data preparation through training and evaluation to deployment.

