Reproducing nanoGPT: train.py (a detailed breakdown)

Published: 2025-07-02

The original train.py opens with a very long block of parameter definitions, which is dizzying to read, so I reorganized them into a few modules, placing each group of parameters next to the code that uses it. In practice, though, keeping them all at the top really is best, because scattering them through the file can cause ordering and overriding problems. Splitting them up here is purely to make the script easier to understand.

Two more parameters, out_dir and device, are used by almost every module, so they go right at the very top:

out_dir = 'out'
device = 'cuda'
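All of the snippets below also assume the same set of imports as nanoGPT's train.py (plus the project's model.py on the path); I collect them here so each module can be read on its own:

import os
import time
import math
import pickle
from contextlib import nullcontext

import numpy as np
import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from model import GPTConfig, GPT  # nanoGPT's model definition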

1. Generating training/validation data

#-------------- data loading ---------------------------------
dataset = 'hongloumeng_char'
data_dir = os.path.join('data', dataset)
device_type = 'cuda' if 'cuda' in device else 'cpu'
batch_size = 64
def get_batch(split):
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+block_size+1]).astype(np.int64)) for i in ix])

    if device_type == 'cuda':
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y
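As a quick sanity check (my own snippet, not part of train.py, and it assumes block_size from section 3 is already defined because of the reordering mentioned in the intro): x and y both come out with shape (batch_size, block_size), and y is x shifted by one token, so position t of x is trained to predict position t of y.

# sanity check: inspect one batch (assumes block_size is already defined)
x, y = get_batch('train')
print(x.shape, y.shape)                   # torch.Size([64, 256]) torch.Size([64, 256])
print(torch.equal(x[0, 1:], y[0, :-1]))   # True: y is x shifted by one position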

2. Learning rate schedule

#--------------- learning rate schedule -----------------------------
warm_up_iter = 100
learning_rate = 1e-3
lr_decay_iter = 5000
min_lr = 1e-4
decay_lr = True
def get_lr(it):
    if it < warm_up_iter:
        return learning_rate * (it + 1) / (warm_up_iter + 1)
    elif it > lr_decay_iter:
        return min_lr
    else:
        lr_ratio = (it - warm_up_iter) / (lr_decay_iter - warm_up_iter)
        coeff = 0.5 * (math.cos(lr_ratio * math.pi) + 1)
        return min_lr + coeff * (learning_rate - min_lr)
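So the schedule is a linear warmup over the first warm_up_iter steps, a cosine decay from learning_rate down to min_lr until lr_decay_iter, and a constant min_lr afterwards. A quick check of a few milestone values (my own snippet):

# sanity check: learning rate at a few milestones
for it in [0, 50, 100, 2550, 5000, 6000]:
    print(it, f"{get_lr(it):.2e}")
# it=100 reaches the peak learning_rate = 1e-3, the cosine midpoint (it = 2550)
# sits at (learning_rate + min_lr) / 2 = 5.5e-4, and anything past 5000 stays at min_lr = 1e-4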

3. Model initialization

#-------------- model initialization --------------------------------
"""
class Config:
    block_size: int = 1024
    vocab_size: int = 50304
    n_embd: int = 768
    bias: bool = True
    n_layer: int = 12
    n_head: int = 12
    dropout: float = 0.0
# the default config from the model definition, for reference
"""


block_size = 256
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2
bias = False
init_form = "scratch"
meta_vocab_size = None

model_args = dict(block_size=block_size, vocab_size=None, n_embd=n_embd,
                 bias=bias, n_layer=n_layer, n_head=n_head, dropout=dropout)

iter_num = 0
best_val_loss = 1e9

meta_path = os.path.join(data_dir, 'meta.pkl')
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vacab_size={meta_vocab_size}(insida{meta_path})")

if init_form == "scratch":
    print("Initializing a new model from scratch")
    if meta_vocab_size is None:
        print("defaulting to vacab_size of GPT-2 to 50304")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_form == "resume":
    out_dir = 'out-hongloumeng-char'
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, "ckpt.pt")
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_conf_arg = checkpoint['model_args']
    for k in ['block_size', 'vocab_size', 'n_embd', 'bias', 'n_layer',
              'n_head']:
        model_args[k] = checkpoint_conf_arg[k]
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)

    # checkpoints saved from a torch.compile'd model carry an '_orig_mod.' prefix on every key; strip it before loading
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']

elif init_form.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights:{init_form}")
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_form, override_args)
    for k in ['block_size', 'vocab_size', 'n_embd', 'bias', 'n_layer',
              'n_head']:
        model_args[k] = getattr(model.config, k)
# crop the model's context length when our block_size is smaller than the loaded / pretrained one
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size
model.to(device)

4. Compilation

#-------------- compile -------------------------------------
compile = True
if compile:
    print("compiling the model ...(take a ~minute)")
    unoptimized_model = model
    model = torch.compile(model)

However, I found that torch.compile apparently only works on UNIX-like systems, because it depends on the triton library, which seems to ship UNIX builds only. If you hit related errors, you can add the following at the top of the file:

#import torch._dynamo
#torch._dynamo.config.suppress_errors = True
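Alternatively (my own guard, not something nanoGPT does), you can simply skip compilation on Windows instead of suppressing the dynamo errors, by putting this before the if compile: check:

# optional: disable torch.compile on Windows, where triton is generally unavailable
import sys
compile = compile and sys.platform != 'win32'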

5. Optimizer setup

#---------------optimizer----------------------------------
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
#beta2 = 0.99
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_form == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None
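configure_optimizers lives in nanoGPT's model.py. Roughly, it applies weight decay only to parameter tensors with two or more dimensions (matmul weights and embeddings), leaves biases and LayerNorm weights undecayed, and builds an AdamW optimizer from the two groups. A minimal sketch of that idea (not the exact nanoGPT code):

# minimal sketch of the idea behind configure_optimizers (not nanoGPT's exact code)
def configure_optimizers_sketch(model, weight_decay, learning_rate, betas):
    params = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in params if p.dim() >= 2]      # weights, embeddings
    nodecay_params = [p for p in params if p.dim() < 2]     # biases, LayerNorm params
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)

The real method also checks device_type so it can pick the fused AdamW kernel on CUDA, which is why device_type is passed in above.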

6. DDP

#--------------- DDP (and master_process) -------------------
gradient_accumulation_steps = 1
backend = 'nccl'
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0
    seed_offset = ddp_rank
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be :{tokens_per_iter}")

if master_process:  # only the master process creates the output directory
    os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
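The DDP branch only activates when the script is launched through a launcher that sets RANK / LOCAL_RANK / WORLD_SIZE, e.g. torchrun --standalone --nproc_per_node=4 train.py for a single node with 4 GPUs. A plain python train.py leaves RANK unset, so it defaults to -1 and the single-process branch runs.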

7. Loss estimation

#-------------- loss estimation ---------------------------------
from contextlib import nullcontext
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
ptdtype = {'float32':torch.float32, 'bfloat16':torch.bfloat16, 'float16':torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
eval_iters = 200
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out

8. Main training loop

#--------------- main training loop -------------------------------------
eval_interval = 250
# collect all config values for logging / checkpointing
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys}
# wandb logging
wandb_log = False
wandb_project = 'hongloumeng-char'
wandb_run_name = 'mini-gpt' # 'run' + str(time.time())
always_keep_checkpoint = False
running_mfu = -1.0
raw_model = model.module if ddp else model
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)
eval_only = False

X, Y = get_batch('train')

scaler = torch.amp.GradScaler('cuda', enabled=(dtype=='float16'))
grad_clip = 1.0

t0 = time.time()

log_interval = 10
local_iter_num = 0
max_iters = 5000
while True:
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # evaluate, log, and (optionally) save a checkpoint
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                'iter': iter_num,
                'loss/val': losses['val'],
                'loss/train': losses['train'],
                'lr': lr,
                'mfu': running_mfu * 100
            })
        if losses['val'] <= best_val_loss or always_keep_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model':raw_model.state_dict(),
                    'optimizer':optimizer.state_dict(),
                    'model_args':model_args,
                    'iter_num':iter_num,
                    'best_val_loss':best_val_loss,
                    'config':config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break
    # forward/backward passes, with gradient accumulation
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps
        X, Y = get_batch('train')
        scaler.scale(loss).backward()

    # gradient clipping: unscale first so the clip threshold applies to the true gradients
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)


    # timing and MFU (model FLOPs utilization) logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5:  # let the training loop settle before measuring MFU
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter{iter_num}:loss{lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num +=1
    local_iter_num += 1

    if iter_num > max_iters:
        break

if ddp:
    destroy_process_group()

When implementing this yourself, it works well to write the training backbone first and then gradually fill in the modules above as the backbone turns out to need them; a rough sketch of that backbone follows.
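For reference, this is the kind of bare-bones backbone I mean: just the fetch-forward-backward-step loop, with evaluation, AMP, gradient accumulation, checkpointing, and DDP all stripped out (a sketch built from the pieces above, not the final code):

# bare-bones training backbone: fetch a batch, forward, backward, step, repeat
X, Y = get_batch('train')
iter_num = 0
while True:
    logits, loss = model(X, Y)
    X, Y = get_batch('train')              # prefetch the next batch
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    if iter_num % log_interval == 0:
        print(f"iter {iter_num}: loss {loss.item():.4f}")
    iter_num += 1
    if iter_num > max_iters:
        break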

