基于 LoRA的广义知识蒸馏(GKD)训练

发布于:2025-08-17 ⋅ 阅读:(16) ⋅ 点赞:(0)

基于 LoRA的广义知识蒸馏(GKD)训练

flyfish

通过参数高效的 LoRA(低秩适应)技术,结合广义知识蒸馏(GKD)方法,让小尺寸的学生模型(如 Qwen2-0.5B-Instruct)高效学习大尺寸教师模型(如 Qwen2-1.5B-Instruct)的知识和能力,最终在减少计算资源消耗的前提下,提升小模型的对话性能,使其接近大模型的水平。

python examples/scripts/gkd.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \  # 学生模型:小模型,待蒸馏的模型
    --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \  # 教师模型:大模型,提供知识的模型
    --dataset_name trl-lib/chatbot_arena_completions \  # 训练数据集:对话竞技场数据(含高质量对话)
    --learning_rate 2e-4 \  # 学习率:LoRA微调通常用更大的学习率(全量训练一般2e-5)
    --per_device_train_batch_size 4 \  # 单设备训练批次大小
    --gradient_accumulation_steps 8 \  # 梯度累积步数:总批次=4*8=32(节省显存)
    --output_dir gkd-model \  # 模型保存路径
    --num_train_epochs 1 \  # 训练轮次
    --push_to_hub \  # 训练后推送到Hugging Face Hub
    --gradient_checkpointing \  # 启用梯度检查点:牺牲少量速度换显存
    --use_peft \  # 启用PEFT(参数高效微调)框架,这里用于LoRA
    --lora_r 64 \  # LoRA的秩(秩越低,参数量越少)
    --lora_alpha 16  # LoRA的缩放系数(控制更新幅度)

解析参数 → 加载模型 / Tokenizer → 加载数据集 → 初始化 GKD 训练器 → 执行训练 → 保存模型。

# 导入必要的库
# 加载数据集的工具
from datasets import load_dataset
# 加载分词器和生成配置的工具
from transformers import AutoTokenizer, GenerationConfig

# 从trl库导入GKD相关的配置、训练器和工具
from trl import (
    GKDConfig,  # GKD训练的核心配置类
    GKDTrainer,  # GKD训练器,用于实现广义知识蒸馏
    LogCompletionsCallback,  # 记录生成结果的回调函数,用于评估
    ModelConfig,  # 模型相关配置(如LoRA参数、量化设置等)
    ScriptArguments,  # 脚本级参数(如数据集路径、分裂等)
    TrlParser,  # trl库专用的参数解析器
    get_kbit_device_map,  # 获取量化模型的设备映射(自动分配GPU/CPU)
    get_peft_config,  # 获取PEFT配置(如LoRA参数)
    get_quantization_config,  # 获取量化配置(如4/8位量化)
)


if __name__ == "__main__":
    # 初始化参数解析器,支持解析三类配置:脚本参数、GKD训练配置、模型配置
    parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig))
    # 解析命令行参数,得到三个配置对象
    # script_args:数据集路径、分裂等脚本级参数
    # training_args:GKD训练的核心参数(学习率、批次大小等)
    # model_args:模型相关参数(LoRA配置、量化设置等)
    script_args, training_args, model_args = parser.parse_args_and_config()

    ################
    # 模型与分词器配置
    ################
    # 根据model_args获取量化配置(如4位/8位量化),用于减少显存占用
    quantization_config = get_quantization_config(model_args)
    
    # 定义学生模型的初始化参数
    model_kwargs = dict(
        revision=model_args.model_revision,  # 模型版本(如特定commit哈希)
        trust_remote_code=model_args.trust_remote_code,  # 是否信任模型的自定义代码(如非标准架构)
        attn_implementation=model_args.attn_implementation,  # 注意力实现方式(如flash attention加速)
        torch_dtype=model_args.torch_dtype,  # 数据类型(如float16/bfloat16,节省显存)
        # 启用梯度检查点时禁用缓存(两者冲突),否则启用缓存加速
        use_cache=False if training_args.gradient_checkpointing else True,
        # 量化时自动分配设备(GPU/CPU),非量化时不指定
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,  # 量化配置(如4位量化参数)
    )
    # 将学生模型参数传递给训练配置
    training_args.model_init_kwargs = model_kwargs

    # 定义教师模型的初始化参数(与学生模型类似,但有细微差别)
    teacher_model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=True,  # 教师模型仅用于推理,启用缓存加速生成
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    # 将教师模型参数传递给训练配置
    training_args.teacher_model_init_kwargs = teacher_model_kwargs

    # 加载分词器(与学生模型匹配,确保格式一致)
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,  # 分词器路径(与学生模型相同)
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",  # 左填充(生成任务常用,避免右填充影响生成逻辑)
    )
    # 若分词器未定义pad_token,用eos_token代替(确保填充功能正常)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ################
    # 数据集加载
    ################
    # 加载指定数据集(如trl-lib/chatbot_arena_completions对话数据集)
    # script_args.dataset_name:数据集名称,script_args.dataset_config:数据集配置(如子数据集)
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    ################
    # 训练初始化
    ################
    # 初始化GKD训练器,核心组件
    trainer = GKDTrainer(
        model=model_args.model_name_or_path,  # 学生模型路径(如Qwen/Qwen2-0.5B-Instruct)
        teacher_model=training_args.teacher_model_name_or_path,  # 教师模型路径(如Qwen/Qwen2-1.5B-Instruct)
        args=training_args,  # 训练配置(学习率、批次大小等)
        train_dataset=dataset[script_args.dataset_train_split],  # 训练集(如dataset["train"])
        # 验证集(若启用评估)
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,  # 用于数据预处理的分词器
        peft_config=get_peft_config(model_args),  # PEFT配置(如LoRA参数:r=64, alpha=16)
    )

    # 若启用评估策略(如每轮评估),配置生成参数并添加回调
    if training_args.eval_strategy != "no":
        # 定义生成配置(控制模型生成行为)
        generation_config = GenerationConfig(
            max_new_tokens=training_args.max_new_tokens,  # 最大生成长度
            do_sample=True,  # 启用采样(而非贪心生成)
            temperature=training_args.temperature  # 温度参数(控制生成多样性,值越大越随机)
        )
        # 初始化回调函数:记录评估时的生成结果(如保存8个示例)
        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
        # 向训练器添加回调
        trainer.add_callback(completions_callback)

    # 启动训练
    trainer.train()

    # 保存模型到输出目录
    trainer.save_model(training_args.output_dir)
    # 若启用push_to_hub,将模型推送到Hugging Face Hub
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)

GKDTrainer

import os
import random
import textwrap
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,  # 用于加载因果语言模型(如GPT类模型)
    BaseImageProcessor,  # 图像处理基类(此处未直接使用,为兼容多模态预留)
    DataCollator,  # 数据整理器,用于批量处理数据
    FeatureExtractionMixin,  # 特征提取混入类(兼容多模态)
    GenerationConfig,  # 生成配置,控制模型生成行为(如长度、温度等)
    PreTrainedModel,  # 预训练模型基类
    PreTrainedTokenizerBase,  # 预训练分词器基类
    ProcessorMixin,  # 处理器混入类(兼容多模态处理器)
    is_wandb_available,  # 检查是否安装wandb(实验跟踪工具)
)
from transformers.trainer_callback import TrainerCallback  # 训练回调基类
from transformers.trainer_utils import EvalPrediction  # 评估预测结果格式
from transformers.utils import is_peft_available  # 检查是否安装PEFT(参数高效微调工具)

from ..models import prepare_deepspeed  # 准备Deepspeed配置(分布式训练)
from ..models.utils import unwrap_model_for_generation  # 为生成任务解包模型(如处理PEFT包装)
from .gkd_config import GKDConfig  # GKD训练的核心配置类
from .sft_trainer import SFTTrainer  # 监督微调训练器(GKDTrainer的父类)
from .utils import (
    DataCollatorForChatML,  # 针对ChatML格式的数据集整理器
    disable_dropout_in_model,  # 禁用模型中的dropout层(稳定训练)
    empty_cache,  # 清空GPU缓存(节省显存)
    generate_model_card,  # 生成模型卡片(README.md)
    get_comet_experiment_url,  # 获取Comet实验跟踪URL(若使用)
)


# 条件导入:仅当PEFT库可用时导入PeftConfig
if is_peft_available():
    from peft import PeftConfig

# 条件导入:仅当wandb可用时导入wandb(实验跟踪)
if is_wandb_available():
    import wandb


class GKDTrainer(SFTTrainer):
    """
    广义知识蒸馏(Generalized Knowledge Distillation)训练器,继承自监督微调训练器(SFTTrainer)。
    核心功能:通过教师模型指导学生模型训练,结合动态生成样本(on-policy学习)和广义JSD损失,提升小模型性能。
    """
    _tag_names = ["trl", "gkd"]  # 模型卡片标签

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,  # 学生模型(可传入路径或实例)
        teacher_model: Union[PreTrainedModel, nn.Module, str] = None,  # 教师模型(可传入路径或实例)
        args: Optional[GKDConfig] = None,  # GKD训练配置(含蒸馏参数、训练超参等)
        data_collator: Optional[DataCollator] = None,  # 数据整理器(默认为ChatML格式)
        train_dataset: Optional[Dataset] = None,  # 训练数据集
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,  # 评估数据集
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,  # 数据处理器(通常为分词器)
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,  # 评估指标计算函数
        callbacks: Optional[list[TrainerCallback]] = None,  # 训练回调(如日志、早停等)
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),  # 优化器和学习率调度器
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,  # 处理logits用于计算指标
        peft_config: Optional["PeftConfig"] = None,  # PEFT配置(如LoRA参数)
        formatting_func: Optional[Callable] = None,  # 数据格式化函数(将样本转为模型输入格式)
    ):
        # 禁用自动移除未使用的列(因GKD需要"prompts"等额外字段)
        args.remove_unused_columns = False
        # 初始化数据整理器:使用ChatML格式(适合对话模型),限制最大长度
        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)

        # 调用父类(SFTTrainer)的初始化方法,完成基础训练器配置
        super().__init__(
            model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            formatting_func=formatting_func,
        )

        # 处理教师模型的初始化参数
        if args.teacher_model_init_kwargs is None:
            teacher_model_init_kwargs = {}  # 无参数时使用空字典
        elif not isinstance(teacher_model, str):
            # 若教师模型已实例化,则不允许传入初始化参数(避免冲突)
            raise ValueError(
                "已传入实例化的teacher_model,但同时指定了teacher_model_init_kwargs,两者冲突。"
            )
        else:
            teacher_model_init_kwargs = args.teacher_model_init_kwargs
            # 处理数据类型参数(将字符串转为torch dtype,如"float16"→torch.float16)
            teacher_model_init_kwargs["torch_dtype"] = (
                teacher_model_init_kwargs["torch_dtype"]
                if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
                else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
            )

        # 若教师模型是路径字符串,则加载预训练模型
        if isinstance(teacher_model, str):
            teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

        # 禁用学生模型的dropout层(稳定蒸馏过程,减少随机性)
        if args.disable_dropout:
            disable_dropout_in_model(self.model)

        # 准备教师模型:若启用Deepspeed(分布式训练),则适配Deepspeed;否则用accelerator准备
        if self.is_deepspeed_enabled:
            self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
        else:
            # 将教师模型设为评估模式(不训练,仅用于推理)
            self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)

        # 初始化GKD核心超参数
        self.lmbda = args.lmbda  # 学生自生成样本的概率(on-policy学习概率)
        self.beta = args.beta  # 广义JSD损失的插值系数
        self.temperature = args.temperature  # 概率分布的温度系数(控制平滑度)
        self.seq_kd = args.seq_kd  # 是否强制使用教师生成的序列进行蒸馏

        # 初始化生成配置(控制动态样本生成的行为)
        self.generation_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,  # 最大生成token数
            temperature=args.temperature,  # 生成温度(值越大越随机)
            do_sample=True,  # 启用采样(而非贪心生成)
            top_k=0,  # 不限制top_k(配合温度控制多样性)
            use_cache=False if args.gradient_checkpointing else True,  # 梯度检查点启用时禁用缓存(冲突)
            pad_token_id=self.processing_class.pad_token_id,  # 填充token ID
        )
        # 适配模型自定义的EOS token(如Llama 3的<|eot_id|>)
        if (
            hasattr(self.model.generation_config, "eos_token_id")
            and self.model.generation_config.eos_token_id is not None
        ):
            self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

    @staticmethod
    def generalized_jsd_loss(
        student_logits,  # 学生模型的logits,形状:(batch_size, seq_len, vocab_size)
        teacher_logits,  # 教师模型的logits,形状同上
        labels=None,  # 标签,形状:(batch_size, seq_len),-100表示padding(忽略)
        beta=0.5,  # 插值系数(控制教师/学生分布权重)
        temperature=1.0,  # 温度系数(软化概率分布)
        reduction="batchmean",  # 损失聚合方式(batchmean/sum/mean)
    ):
        """
        计算广义Jensen-Shannon散度(JSD)损失,用于知识蒸馏。
        参考论文:https://huggingface.co/papers/2306.13649 公式(1)
        """

        # 温度缩放:软化概率分布(温度越高,分布越平滑)
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature

        # 计算学生和教师的对数概率(log_softmax)
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

        if beta == 0:
            # beta=0:退化为传统KL散度(学生模仿教师)
            # F.kl_div(input=学生对数概率, target=教师对数概率, log_target=True)
            jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
        elif beta == 1:
            # beta=1:反向KL散度(教师模仿学生,适合学生容量较小时避免模式崩溃)
            jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
        else:
            # 混合分布的对数概率:log[(1-beta)*P_student + beta*P_teacher]
            # 等价于log(exp(log(1-beta) + log_P_student) + exp(log(beta) + log_P_teacher))
            beta = torch.tensor(beta, dtype=student_log_probs.dtype)  # 转为tensor(匹配设备和类型)
            mixture_log_probs = torch.logsumexp(
                torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
                dim=0,  # 按第0维(学生/教师)求和
            )

            # 计算混合分布与教师/学生分布的KL散度(注意PyTorch的KL顺序与数学定义相反)
            kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
            kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

            # 广义JSD:beta*KL(混合||教师) + (1-beta)*KL(混合||学生)
            jsd = beta * kl_teacher + (1 - beta) * kl_student

        # 掩码处理:忽略padding位置(labels=-100)的损失
        if labels is not None:
            mask = labels != -100  # 有效位置为True,padding为False
            jsd = jsd[mask]  # 只保留有效位置的损失

        # 损失聚合(根据reduction参数)
        if reduction == "batchmean":
            # 按有效样本数平均(避免padding影响)
            if labels is not None:
                return jsd.sum() / mask.sum()  # 总损失 / 有效token数
            else:
                return jsd.sum() / (jsd.size(0) * jsd.size(1))  # 总损失 / 总token数(无标签时)
        elif reduction == "sum":
            return jsd.sum()  # 求和
        elif reduction == "mean":
            return jsd.mean()  # 简单平均
        else:
            return jsd  # 不聚合,返回原始损失 tensor

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        计算GKD的损失:通过广义JSD损失让学生模仿教师的输出分布。
        """
        # 学生模型前向传播(获取logits)
        outputs_student = model(
            input_ids=inputs["input_ids"],  # 输入token ID
            attention_mask=inputs["attention_mask"],  # 注意力掩码(0表示padding)
        )

        # 教师模型前向传播(评估模式,不计算梯度)
        self.teacher_model.eval()  # 确保教师模型处于评估模式(禁用dropout等)
        with torch.no_grad():  # 禁用梯度计算(节省显存)
            outputs_teacher = self.teacher_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        # 切片处理:只保留生成部分的logits(排除输入prompt部分)
        prompt_lengths = inputs["prompts"].shape[1]  # prompt的长度(输入部分,无需预测)
        # 学生logits:从prompt结束位置的前一个token开始,到序列结束前一个token(因语言模型预测下一个token)
        shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
        # 教师logits:同上(与学生对齐)
        shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
        # 标签:从prompt结束位置开始(生成部分的真实标签)
        shifted_labels = inputs["labels"][:, prompt_lengths:]

        # 计算广义JSD损失
        loss = self.generalized_jsd_loss(
            student_logits=shifted_student_logits,
            teacher_logits=shifted_teacher_logits,
            labels=shifted_labels,  # 用于掩码padding
            beta=self.beta,  # 从初始化参数获取
        )

        # 清空GPU缓存(节省显存)
        empty_cache()

        # 返回损失(可选返回学生模型输出)
        return (loss, outputs_student) if return_outputs else loss

    @staticmethod
    def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
        """
        生成动态样本(on-policy学习用):基于输入prompt生成输出序列,作为新的训练数据。
        """
        # 基于prompt生成输出(仅用prompt作为输入,不包含原始标签)
        generated_outputs = model.generate(
            input_ids=inputs["prompts"],  # 输入prompt的token ID
            attention_mask=inputs.get("prompt_attention_mask", None),  # prompt的注意力掩码
            generation_config=generation_config,  # 生成配置(长度、温度等)
            return_dict_in_generate=True,  # 返回详细生成结果(含序列、分数等)
        )

        # 获取生成的token ID序列
        generated_tokens = generated_outputs.sequences
        # 初始化新的注意力掩码(全1,后续修正padding位置)
        new_attention_mask = torch.ones_like(generated_tokens)
        # 新标签:复制生成的token(后续修正padding位置)
        new_labels = generated_tokens.clone()

        # 处理padding token(若指定)
        if pad_token_id is not None:
            # 标签中padding位置设为-100(忽略损失)
            new_labels[new_labels == pad_token_id] = -100
            # 注意力掩码中padding位置设为0(不参与注意力计算)
            new_attention_mask[generated_tokens == pad_token_id] = 0

        # 返回生成的输入ID、注意力掩码、标签
        return generated_tokens, new_attention_mask, new_labels

    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        """
        单步训练:实现on-policy学习,动态生成样本用于训练。
        逻辑:以概率lmbda使用学生自生成样本,或强制使用教师生成样本(seq_kd=True)。
        """
        if self.seq_kd:
            # seq_kd=True:强制使用教师模型生成样本(适合初始训练阶段,学习教师的"正确"输出)
            # 解包教师模型(处理PEFT/分布式包装)
            with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
                # 生成教师样本
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            # 更新输入为教师生成的样本
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels
        # 以概率lmbda使用学生自生成样本(on-policy学习核心)
        if random.random() <= self.lmbda:
            # 解包学生模型(处理PEFT/分布式包装)
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                # 生成学生样本
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
                )
            # 更新输入为学生生成的样本(让学生从自身生成的结果中学习)
            inputs["input_ids"] = new_input_ids
            inputs["attention_mask"] = new_attention_mask
            inputs["labels"] = new_labels

        # 调用父类的training_step计算损失并更新参数
        loss = super().training_step(model, inputs, num_items_in_batch)
        return loss

    def create_model_card(
        self,
        model_name: Optional[str] = None,  # 模型名称
        dataset_name: Optional[str] = None,  # 训练数据集名称
        tags: Union[str, list[str], None] = None,  # 模型标签
    ):
        """
        生成模型卡片(README.md),包含训练信息、引用、标签等,方便上传到Hugging Face Hub。
        """
        # 仅在主进程执行(避免多进程重复生成)
        if not self.is_world_process_zero():
            return

        # 确定基座模型名称(若模型从预训练模型微调而来)
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # 标准化标签(转为集合避免重复)
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        # 若使用unsloth加速训练,添加对应标签
        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")

        # 添加默认标签(trl和gkd)
        tags.update(self._tag_names)

        # GKD论文引用格式
        citation = textwrap.dedent("""\
        @inproceedings{agarwal2024on-policy,
            title        = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
            author       = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
            year         = 2024,
            booktitle    = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
            publisher    = {OpenReview.net},
            url          = {https://openreview.net/forum?id=3zKtaqxLhW},
        }""")

        # 生成模型卡片内容
        model_card = generate_model_card(
            base_model=base_model,  # 基座模型
            model_name=model_name,  # 模型名称
            hub_model_id=self.hub_model_id,  # Hub上的模型ID
            dataset_name=dataset_name,  # 训练数据集
            tags=tags,  # 标签
            wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,  # wandb实验URL
            comet_url=get_comet_experiment_url(),  # Comet实验URL
            trainer_name="GKD",  # 训练器名称
            trainer_citation=citation,  # 引用
            paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",  # 论文标题
            paper_id="2306.13649",  # 论文ID(arXiv或OpenReview)
        )

        # 保存模型卡片到输出目录
        model_card.save(os.path.join(self.args.output_dir, "README.md"))

GKDConfig

from dataclasses import dataclass, field
from typing import Any, Optional

from transformers import TrainingArguments  # 导入Hugging Face的训练参数基类

from .sft_config import SFTConfig  # 导入监督微调(SFT)的配置类(GKDConfig的父类)


@dataclass
class GKDConfig(SFTConfig):
    """
    广义知识蒸馏(Generalized Knowledge Distillation, GKD)训练器的配置类。
    
    此类仅包含GKD训练特有的参数,完整的训练参数请参考`transformers.TrainingArguments`和`SFTConfig`的文档。
    
    参数说明:
        temperature (`float`, 可选, 默认值 `0.9`):
            采样温度。温度越高,生成的结果随机性越强。
        lmbda (`float`, 可选, 默认值 `0.5`):
            控制学生自生成样本比例的参数(即on-policy学习中,使用学生自己生成的输出进行训练的比例)。
        beta (`float`, 可选, 默认值 `0.5`):
            广义Jensen-Shannon散度(JSD)损失的插值系数,范围在`0.0`到`1.0`之间。
            当beta=0.0时,损失退化为传统KL散度(学生模仿教师);当beta=1.0时,损失为反向KL散度(教师模仿学生)。
        max_new_tokens (`int`, 可选, 默认值 `128`):
            每次生成的最大token数量。
        teacher_model_name_or_path (`str` 或 `None`, 可选, 默认值 `None`):
            教师模型的名称或路径。若为`None`,则教师模型与当前训练的模型相同。
        teacher_model_init_kwargs (`dict[str, Any]` 或 `None`, 可选, 默认值 `None`):
            从字符串实例化教师模型时,传递给`AutoModelForCausalLM.from_pretrained`的关键字参数。
        disable_dropout (`bool`, 可选, 默认值 `True`):
            是否禁用模型中的dropout层(蒸馏中常用,以减少随机性,稳定训练)。
        seq_kd (`bool`, 可选, 默认值 `False`):
            是否执行序列级蒸馏(Sequence-Level KD),可视为在教师生成的输出上进行监督微调。
    """

    # 扩展有效字典字段:在TrainingArguments的基础上添加教师模型的初始化参数
    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]

    temperature: float = field(
        default=0.9,
        metadata={"help": "采样温度。温度越高,生成的结果随机性越强。"},
    )
    lmbda: float = field(
        default=0.5,
        metadata={
            "help": "控制学生自生成样本比例的参数(即on-policy学习中,使用学生自己生成的输出进行训练的比例)。"
        },
    )
    beta: float = field(
        default=0.5,
        metadata={
            "help": "广义Jensen-Shannon散度(JSD)损失的插值系数,范围在0.0到1.0之间。"
            "当beta=0.0时,损失为KL散度(学生模仿教师);当beta=1.0时,损失为反向KL散度(教师模仿学生)。"
        },
    )
    max_new_tokens: int = field(
        default=128,
        metadata={"help": "每次生成的最大token数量。"},
    )
    teacher_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "教师模型的名称或路径。若为None,教师模型将与当前训练的模型相同。"
        },
    )
    teacher_model_init_kwargs: Optional[dict[str, Any]] = field(
        default=None,
        metadata={
            "help": "从字符串实例化教师模型时,传递给AutoModelForCausalLM.from_pretrained的关键字参数。"
        },
    )
    disable_dropout: bool = field(
        default=True,
        metadata={"help": "是否禁用模型中的dropout层(蒸馏中常用以稳定训练)。"},
    )
    seq_kd: bool = field(
        default=False,
        metadata={
            "help": "是否执行序列级蒸馏(可视为在教师生成的输出上进行监督微调)。"
        },
    )

    def __post_init__(self):
        """初始化后执行的方法:调用父类初始化逻辑,并验证参数合法性。"""
        super().__post_init__()  # 调用父类(SFTConfig)的初始化后处理逻辑
        # 验证lmbda参数是否在[0, 1]范围内
        if self.lmbda < 0.0 or self.lmbda > 1.0:
            raise ValueError("lmbda参数必须在[0.0, 1.0]范围内。")
        # 验证beta参数是否在[0, 1]范围内
        if self.beta < 0.0 or self.beta > 1.0:
            raise ValueError("beta参数必须在[0.0, 1.0]范围内。")

网站公告

今日签到

点亮在社区的每一天
去签到