2025-10-13 06:49:24 -07:00
"""
Midtrain the model. Same as pretraining but simpler.
Run as:
python -m scripts.mid_train
Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=16
"""
from collections import deque
import os
os . environ [ " PYTORCH_CUDA_ALLOC_CONF " ] = " expandable_segments:True "
import time
import wandb
import torch
from nanochat . common import compute_init , compute_cleanup , print0 , DummyWandb , get_base_dir
from nanochat . tokenizer import get_token_bytes
from nanochat . checkpoint_manager import save_checkpoint
from nanochat . loss_eval import evaluate_bpb
from nanochat . checkpoint_manager import load_model
import torch . distributed as dist
from tasks . common import TaskMixture
from tasks . gsm8k import GSM8K
from tasks . mmlu import MMLU
from tasks . smoltalk import SmolTalk
# -----------------------------------------------------------------------------
# Default hyperparameters for midtraining. They are plain module-level globals; the
# configurator exec'd further down may override them from the command line or a config file.
run = " dummy " # wandb run name default ("dummy" is special - we won't log to wandb)
model_tag = None # model tag to load the model from (base model or midtrained model); forwarded to load_model
step = None # step to load the model from (base model or midtrained model); NOTE: this name is rebound to the training step counter before the training loop
dtype = " bfloat16 " # autocast precision: 'float32' selects torch.float32, any other value selects torch.bfloat16
max_seq_len = 2048 # number of tokens in each row of a batch
device_batch_size = 32 # number of rows per micro-batch on a single rank
unembedding_lr = 0.004 # forwarded to model.setup_optimizers
embedding_lr = 0.2 # forwarded to model.setup_optimizers
matrix_lr = 0.02 # forwarded to model.setup_optimizers
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0 # forwarded to model.setup_optimizers
eval_every = 150 # run the validation bpb evaluation every this many steps (and on the last step)
eval_tokens = 20 * 524288 # token budget of one validation pass, summed over all ranks
total_batch_size = 524288 # tokens per optimizer step over all ranks; must be a multiple of device_batch_size * max_seq_len * world size
2025-10-15 16:35:04 +00:00
dry_run = 0 # dry_run=1 is for experiments: wandb logging still happens, but the final checkpoint save and the report entry are both skipped
2025-10-13 06:49:24 -07:00
# Snapshot the names of the tunable settings: every global defined so far whose name does not
# start with an underscore and whose value is an int, float, bool or str.
# NOTE: model_tag and step default to None, so they fail the isinstance filter and are
# therefore absent from config_keys (and from user_config below).
config_keys = [ k for k , v in globals ( ) . items ( ) if not k . startswith ( ' _ ' ) and isinstance ( v , ( int , float , bool , str ) ) ]
# NOTE(review): the configurator source is exec'd in this module's namespace and is located via
# a relative path, so this presumably expects the repo root as working directory - confirm.
exec ( open ( os . path . join ( ' nanochat ' , ' configurator.py ' ) ) . read ( ) ) # overrides from command line or config file
# Final (post-override) values of the tunables; passed to wandb.init, the checkpoint metadata and the report.
user_config = { k : globals ( ) [ k ] for k in config_keys } # possibly useful for logging
# -----------------------------------------------------------------------------
# Distributed / device setup
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = (ddp_rank == 0)

# Turn the config string into the torch dtype used for autocast:
# 'float32' selects full precision, anything else selects bfloat16.
if dtype == 'float32':
    dtype = torch.float32
else:
    dtype = torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)

# Experiment tracking: only the master process of a non-"dummy" run talks to wandb,
# every other case gets the no-op DummyWandb stand-in.
use_dummy_wandb = (not master_process) or run == "dummy"
if use_dummy_wandb:
    wandb_run = DummyWandb()
else:
    wandb_run = wandb.init(project="nanochat-mid", name=run, config=user_config)
# Restore the base checkpoint: the model, its tokenizer and the metadata stored with it
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)

# Warn when this run uses a larger per-device batch than the one recorded for the base run
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None:
    if device_batch_size > pretrain_batch_size:
        print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")

# Keep a handle on the uncompiled module (its state_dict is what gets checkpointed),
# and train through the compiled wrapper.
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()

# Batch-size bookkeeping, all measured in tokens
tokens_per_fwdbwd = device_batch_size * max_seq_len  # one micro-batch on one rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size  # one micro-batch over all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd  # micro-batches per optimizer step
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")

# Per-token byte information handed to evaluate_bpb during validation
token_bytes = get_token_bytes(device=device)
# Build the optimizer pair (AdamW first, Muon second) from the per-group learning rates
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
)
adamw_optimizer, muon_optimizer = optimizers

# Scale every group's starting learning rate by init_lr_frac and remember the scaled
# value as "initial_lr"; the scheduler later multiplies this stored value.
for optimizer in optimizers:
    for param_group in optimizer.param_groups:
        scaled_lr = param_group["lr"] * init_lr_frac
        param_group["lr"] = scaled_lr
        param_group["initial_lr"] = scaled_lr
# Midtraining data: task mixtures for training and validation
base_dir = get_base_dir()

train_tasks = [
    SmolTalk(split="train"),  # general conversations (~460K rows)
    MMLU(subset="auxiliary_train", split="train"),  # multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE (~100K rows)
    GSM8K(subset="main", split="train"),  # simple math and (calculator) tool use (~8K rows)
]
train_dataset = TaskMixture(train_tasks)  # ~568K rows in total

val_tasks = [
    SmolTalk(split="test"),  # ~24K rows in the test set
    MMLU(subset="all", split="test", stop=5200),  # ~14K rows available, capped at 5.2K to match the train ratios
    GSM8K(subset="main", split="test", stop=420),  # ~1.32K rows available, capped at 420 to match the train ratios
]
val_dataset = TaskMixture(val_tasks)

# The data generator defined below yields (inputs, targets), two 2D tensors of shape
# (device_batch_size, max_seq_len). The number of iterations is not known up front,
# so the generator reports back through the two module-level variables below.
approx_progress = 0.0  # fraction of the training set consumed so far, goes from 0 to 1 over the epoch
last_step = False  # set to True by the generator once the end of the training set is reached
def mid_data_generator(split):
    """Yield an endless stream of (inputs, targets) token batches for one split.

    Conversations are read from ``train_dataset`` / ``val_dataset`` starting at
    index ``ddp_rank`` and striding by ``ddp_world_size``, rendered to token ids
    with ``tokenizer.render_conversation`` and concatenated into one token
    stream. Each yielded batch consumes ``device_batch_size * max_seq_len + 1``
    tokens from that stream; targets are the inputs shifted left by one token.

    Args:
        split: ``"train"`` or ``"val"``.

    Yields:
        ``(inputs, targets)``: an int32 and an int64 tensor on ``device``, both of
        shape ``(device_batch_size, max_seq_len)``.

    Side effects (``"train"`` split only):
        * sets the global ``last_step`` to True once the cursor wraps around the
          end of the dataset,
        * updates the global ``approx_progress`` with ``cursor / dataset_size``
          before every yield.

    Raises:
        AssertionError: if ``split`` is invalid or the dataset is empty.
    """
    global last_step, approx_progress
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    needed_tokens = device_batch_size * max_seq_len + 1  # +1 because targets are inputs shifted by one
    token_buffer = deque()
    cursor = ddp_rank  # increments by ddp_world_size each time, so each rank processes unique documents
    while True:
        # Accumulate enough tokens for one iteration before yielding
        while len(token_buffer) < needed_tokens:
            conversation = dataset[cursor]
            ids, _ = tokenizer.render_conversation(conversation)
            token_buffer.extend(ids)
            cursor += ddp_world_size
            if cursor >= dataset_size:
                cursor -= dataset_size  # wrap around for another epoch
                if split == "train":
                    last_step = True  # signals the training loop to terminate
        # Move the batch out of the deque with a single bulk conversion rather than one
        # tensor write per token. A fresh pinned tensor is created for every batch on
        # purpose: it is the source of a non_blocking host->device copy below, and
        # refilling a shared buffer for the next batch could race with a copy that has
        # been queued but has not run yet.
        batch_tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
        scratch = torch.tensor(batch_tokens, dtype=torch.int64, pin_memory=True)
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
        if split == "train":
            approx_progress = cursor / dataset_size  # approximate progress as a fraction of the dataset
        yield inputs, targets
# One long-lived generator feeds training for the whole run
train_loader = mid_data_generator("train")


def build_val_loader():
    """Return a new validation generator (its cursor starts over from this rank's first document)."""
    return mid_data_generator("val")


# Monotonic training progress, goes from 0 to 1 over the course of the epoch
progress = 0
# Learning rate scheduler
def get_lr_multiplier(progress):
    """Return the learning-rate multiplier for a training progress fraction.

    The multiplier is 1 while ``progress`` is below 0.8 and then decreases
    linearly, reaching 0 at ``progress == 1``.
    """
    decay_start = 0.8
    if progress < decay_start:
        return 1
    # fraction of the final 20% of training that has already elapsed
    decayed = (progress - decay_start) / 0.2
    return 1 - decayed
2025-10-13 06:49:24 -07:00
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    """Return the Muon momentum for iteration ``it``.

    Linearly interpolates from 0.85 at iteration 0 to 0.95 at iteration 300
    and stays at 0.95 afterwards.
    """
    warmup_iters = 300
    start, end = 0.85, 0.95
    t = min(it / warmup_iters, 1)  # warmup fraction, capped at 1
    return (1 - t) * start + t * end
# -----------------------------------------------------------------------------
# Training loop
# The loop below alternates between (1) agreeing across ranks on whether the data is
# exhausted, (2) periodic validation, (3) the end-of-run checkpoint and (4) one optimizer
# step made of grad_accum_steps micro-batches. `step` counts completed optimizer steps.
x, y = next(train_loader)  # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0  # EMA of training loss
ema_beta = 0.9  # EMA decay factor
total_training_time = 0  # total wall-clock time of training (the first 10 steps are excluded)
step = 0
while True:
    flops_so_far = num_flops_per_token * total_batch_size * step

    # Synchronize last_step across all ranks to avoid hangs in the distributed setting:
    # with a MAX reduction, one rank reaching the end of its data stops every rank.
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or step % eval_every == 0:
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # save checkpoint at the end of the run (only on master process, never in a dry run)
    if master_process and last_step and not dry_run:
        output_dirname = f"d{depth}"  # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),  # weights of the uncompiled module
            [opt.state_dict() for opt in optimizers],  # TODO: make sure saving across ranks is done correctly
            {
                "step": step,
                "val_bpb": val_bpb,  # loss at last step (validation always runs when last_step is set)
                "model_config": {
                    "sequence_len": max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                },
                "user_config": user_config,  # inputs to the training script
            }
        )

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    torch.cuda.synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()  # for logging (only the last micro-batch's loss survives the loop)
        loss = loss / grad_accum_steps  # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y = next(train_loader)  # prefetch the next batch while the GPU is busy with forward/backward
        progress = max(progress, approx_progress)  # only increase progress monotonically
    # step the optimizers
    lrm = get_lr_multiplier(progress)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    muon_momentum = get_muon_momentum(step)
    for group in muon_optimizer.param_groups:
        group["momentum"] = muon_momentum
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # State
    step += 1
    # `step` optimizer steps are now complete; refresh the flops counter so the training
    # log below pairs the new step with the matching amount of compute.
    flops_so_far = num_flops_per_token * total_batch_size * step

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()  # EMA the training loss
    # Debias the EMA. The EMA has absorbed exactly `step` samples at this point (step was
    # already incremented above), so the correction factor is 1 - beta**step.
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta ** step)
    pct_done = 100 * progress
    # dt spans all grad_accum_steps micro-batches, i.e. total_batch_size tokens
    tok_per_sec = int(total_batch_size / dt)
    flops_per_sec = num_flops_per_token * total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size  # bfloat16 H100 SXM and without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100  # in %
    if step > 10:
        total_training_time += dt  # only count the time after the first 10 steps
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
        })
# End-of-run summary on stdout
peak_memory_mib = torch.cuda.max_memory_allocated() / 1024 / 1024
print0(f"Peak memory usage: {peak_memory_mib:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Add a "Midtraining" section to the report (skipped for dry runs)
if not dry_run:
    from nanochat.report import get_report
    setup_stats = {  # stats about the training setup
        "Number of iterations": step,
        "DDP world size": ddp_world_size,
    }
    outcome_stats = {  # stats about training outcomes
        "Minimum validation bpb": min_val_bpb,
    }
    report_data = [user_config, setup_stats, outcome_stats]  # user_config holds the CLI args
    get_report().log(section="Midtraining", data=report_data)

# Shut down: close the wandb run, then release the compute setup
wandb_run.finish()
compute_cleanup()