- 本文介绍我近期的开源项目 CleanGPT,这是一个基于PyTorch实现的GPT类模型训练框架。该项目试图保持清晰、简洁、扩展性和教育性,旨在为科研工作提供一个易于使用的工程模板,基于 NanoGPT 扩展实现
- 本文主要介绍该项目中使用的各种trick
0. 项目结构
- 本项目结构如下所示

- 项目主要文件存储于 data, train, eval 和 model 四个文件夹中,具体而言
- data:用于存储和数据相关的所有文件,其中
data/data.py
定义了 Dataset
, Dataloader
, DistributedSampler
等方法,用于在训练和评估过程中加载数据。每个数据集以独立文件夹形式存储和管理,其中提供了数据集和Tokenizer等相关组件的下载方法,目前提供了 shakespeare_char 和 tinystory 两个数据集

- model:用于存储模型定义代码,目前在
model/NanoGPT.py
中实现了 NanoGPT 模型
- train:用于存储训练相关的代码,其中
train/config.py
使用 argparse
整理了训练所需的所有参数并设置了缺省值和解释;train/scheduler.py
实现了训练调度器,可用于动态调整 learning_rate, batch_size 和 weight_decay 系数等参数,并实现了早停方法;train/train_ddp.py
是启动分布式数据并行训练的入口;trainer.py
实现了完整的训练流程,包括 checkpoint & snapshot 存储、wandb log 等功能

- eval:用于存储模型评估相关的代码,目前在
eval/text_autoregress.py
中实现了简单的自回归生成方法
1. 特性
- 本项目具备以下特性,基本可以覆盖典型科研项目的需求
- 分布式训练:支持基于 PyTorch DDP 的多卡训练框架(不支持其他训练方式,如 CPU 训练)
- 自动混合精度:支持基于
torch.cuda.amp
的混合精度训练(优先尝试 bf16,若不支持则使用 fp16)
- 模型编译加速:支持利用
torch.compile
对模型进行编译优化从而加速训练(要求 Pytorch 2.0 及以上版本)
- 轻量数据加载:利用
np.memmap
构造 Dataloader,不需要将全部数据加载到内存
- 训练调度器:提供了强大的训练调度器,支持 learning-rate, weight decay coefficient 和 training batch-size 的动态调度,使用早停机制避免过拟合
- 断点续训:支持从最新的 snapshot 无感恢复训练过程
- 模型管理:提供了实用的 checkpoint 保存管理机制,可根据设定自动保存最好(即验证损失最低)的n个模型权重,且可从指定 checkpoint 初始化进行微调
- Wandb Log:支持在 Wandb 实时记录训练损失、验证损失、学习率、数据集访问比例等数据曲线
- Macro Batch:由于 Lanuage Model 训练往往使用非常大的数集,整个训练过程可能只遍历数据集几次,甚至无法完整遍历一次,传统的 epoch 概念不再适用。本项目基于 macro-batch 概念进行训练,具体地,batch 是加载数据的最小单位,若干个 batch 组成一个 macro-batch,作为验证损失评估、snapshot & checkpoint 保存的单位
- GPT2: 支持加载 HuggingFace GPT-2 checkpoints 作为初始模型进行微调
2. 实现细节