CleanGPT:清晰简洁的GPT模型训练框架

发布于:2025-03-13 ⋅ 阅读:(17) ⋅ 点赞:(0)
  • 本文介绍我近期的开源项目 CleanGPT,这是一个基于PyTorch实现的GPT类模型训练框架。该项目试图保持清晰、简洁、扩展性和教育性,旨在为科研工作提供一个易于使用的工程模板,基于 NanoGPT 扩展实现
  • 本文主要介绍该项目中使用的各种trick

0. 项目结构

  • 本项目结构如下所示
    在这里插入图片描述
  • 项目主要文件存储于 data, train, eval 和 model 四个文件夹中,具体而言
    1. data:用于存储和数据相关的所有文件,其中 data/data.py 定义了 Dataset, Dataloader, DistributedSampler 等方法,用于在训练和评估过程中加载数据。每个数据集以独立文件夹形式存储和管理,其中提供了数据集和Tokenizer等相关组件的下载方法,目前提供了 shakespeare_char 和 tinystory 两个数据集
      在这里插入图片描述
    2. model:用于存储模型定义代码,目前在 model/NanoGPT.py 中实现了 NanoGPT 模型
    3. train:用于存储训练相关的代码,其中 train/config.py 使用 argparse 整理了训练所需的所有参数并设置了缺省值和解释;train/scheduler.py 实现了训练调度器,可用于动态调整 learning_rate, batch_size 和 weight_decay 系数等参数,并实现了早停方法;train/train_ddp.py 是启动分布式数据并行训练的入口;trainer.py 实现了完整的训练流程,包括 checkpoint & snapshot 存储、wandb log 等功能
      在这里插入图片描述
    4. eval:用于存储模型评估相关的代码,目前在 eval/text_autoregress.py 中实现了简单的自回归生成方法

1. 特性

  • 本项目具备以下特性,基本可以覆盖典型科研项目的需求
    1. 分布式训练:支持基于 PyTorch DDP 的多卡训练框架(不支持其他训练方式,如 CPU 训练)
    2. 自动混合精度:支持基于 torch.cuda.amp 的混合精度训练(优先尝试 bf16,若不支持则使用 fp16)
    3. 模型编译加速:支持利用 torch.compile 对模型进行编译优化从而加速训练(要求 Pytorch 2.0 及以上版本)
    4. 轻量数据加载:利用 np.memmap 构造 Dataloader,不需要将全部数据加载到内存
    5. 训练调度器:提供了强大的训练调度器,支持 learning-rate, weight decay coefficient 和 training batch-size 的动态调度,使用早停机制避免过拟合
    6. 断点续训:支持从最新的 snapshot 无感恢复训练过程
    7. 模型管理:提供了实用的 checkpoint 保存管理机制,可根据设定自动保存最好(即验证损失最低)的n个模型权重,且可从指定 checkpoint 初始化进行微调
    8. Wandb Log:支持在 Wandb 实时记录训练损失、验证损失、学习率、数据集访问比例等数据曲线
    9. Macro Batch:由于 Lanuage Model 训练往往使用非常大的数集,整个训练过程可能只遍历数据集几次,甚至无法完整遍历一次,传统的 epoch 概念不再适用。本项目基于 macro-batch 概念进行训练,具体地,batch 是加载数据的最小单位,若干个 batch 组成一个 macro-batch,作为验证损失评估、snapshot & checkpoint 保存的单位
    10. GPT2: 支持加载 HuggingFace GPT-2 checkpoints 作为初始模型进行微调

2. 实现细节

  • 代码注释已经比较详细了,本节待续

网站公告

今日签到

点亮在社区的每一天
去签到