LLM基础7_用于文本分类的微调

发布于:2025-06-13 ⋅ 阅读:(22) ⋅ 点赞:(0)

基于GitHub项目https://github.com/datawhalechina/llms-from-scratch-cn

微调的概念

  • 预训练:模型在大规模通用文本上学习语言模式(如GPT在互联网文本上训练)

  • 微调:在预训练基础上,用特定领域数据继续训练模型

为什么需要微调?

  1. 领域适应:通用模型在专业领域表现不佳

  2. 任务定制:使模型适应特定任务(如分类、情感分析)

  3. 性能提升:微调后模型在特定任务上表现更好

  4. 数据效率:比从头训练节省90%以上的数据量

文本分类任务概览

文本分类是将文本分配到预定义类别的任务

常见应用场景

  • 情感分析(正面/负面/中性)

  • 主题分类(体育/政治/科技)

  • 垃圾邮件检测

  • 意图识别(客服场景)

  • 新闻分类

微调流程详解

1. 准备领域数据

  • 数据收集:获取与任务相关的文本

  • 数据标注:人工或半自动标注类别

  • 数据格式

# CSV格式示例
text,label
"这个产品太好用了",positive
"服务太差,再也不买了",negative
"手机电池续航一般",neutral

2. 添加分类头

  • 预训练模型:提供基础语言理解能力

  • 分类头:添加在模型顶部的简单神经网络

from transformers import BertForSequenceClassification

# 加载预训练模型
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=3  # 情感分类的类别数
)

3. 训练调整

  • 冻结参数:只训练分类头(轻量微调)
  • 全参数训练:更新所有参数(效果更好但资源消耗大)
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    evaluation_strategy="epoch"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset
)

trainer.train()

微调技术细节

学习率策略

  • 预热学习率:开始小学习率,逐渐增大

  • 衰减学习率:后期减小学习率

from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000
)

类别不平衡处理

当各类别样本数差异大时:

  1. 重采样:过采样少数类,欠采样多数类

  2. 类别权重:在损失函数中增加少数类权重

from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight('balanced', classes, labels)

数据增强技巧

  • 回译:中->英->中生成同义句

  • 同义词替换:使用词嵌入替换同义词

  • 随机插入/删除:增加文本多样性

 评估与优化

指标 公式 适用场景
准确率 (TP+TN)/(TP+FP+FN+TN) 类别平衡
F1值 2(PrecisionRecall)/(Precision+Recall) 类别不平衡
AUC ROC曲线下面积 整体性能

常见问题及解决方案

  1. 过拟合

    • 增加Dropout率

    • 添加L2正则化

    • 早停(Early Stopping)

  2. 欠拟合

    • 增加训练数据

    • 减少正则化

    • 延长训练时间

  3. 部署问题

    • 模型量化(减小模型大小)

    • ONNX格式转换(加速推理)

 实战案例:新闻分类

数据集:THUCNews中文新闻数据集

  • 10个类别:体育、财经、房产、教育等

  • 每类6500条数据,共6.5万条

# 1. 加载预训练模型
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=10
)

# 2. 准备数据
def preprocess_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=128
    )

from datasets import load_dataset
dataset = load_dataset("thucnews")
dataset = dataset.map(preprocess_function, batched=True)

# 3. 训练配置
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./news_classifier",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)

# 4. 开始训练
trainer.train()

# 5. 评估
results = trainer.evaluate()
print(results)

高级技巧

少样本学习(Few-shot Learning)

当标注数据很少时:

1.提示工程(Prompt Engineering)

文本:"苹果发布新款iPhone"
提示:这是一条关于[科技]的新闻

2.模式利用训练(Pattern-Exploiting Training)

  • 将分类任务转化为完形填空

  • "这条新闻的主题是____" → 模型预测[MASK]位置

知识蒸馏

  • 教师模型:大型高精度模型

  • 学生模型:小型高效模型

  • 过程:学生模型学习教师模型的输出分布