2025-10-13 06:49:24 -07:00
"""
Midtrain the model. Same as pretraining but simpler.
Run as:
python -m scripts.mid_train
Or torchrun for training:
2026-01-13 22:45:27 +00:00
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device-batch-size=16
2025-10-13 06:49:24 -07:00
"""
2026-01-04 19:14:23 +00:00
import argparse
2025-10-13 06:49:24 -07:00
import os
2025-12-28 03:32:46 +00:00
os . environ [ " PYTORCH_ALLOC_CONF " ] = " expandable_segments:True "
2025-10-13 06:49:24 -07:00
import time
import wandb
import torch
2025-10-16 16:33:17 -07:00
from contextlib import nullcontext
from nanochat . common import compute_init , compute_cleanup , print0 , DummyWandb , get_base_dir , autodetect_device_type
2025-10-13 06:49:24 -07:00
from nanochat . tokenizer import get_token_bytes
from nanochat . checkpoint_manager import save_checkpoint
from nanochat . loss_eval import evaluate_bpb
from nanochat . checkpoint_manager import load_model
import torch . distributed as dist
from tasks . common import TaskMixture
from tasks . gsm8k import GSM8K
from tasks . mmlu import MMLU
from tasks . smoltalk import SmolTalk
2025-10-21 15:04:58 +00:00
from tasks . customjson import CustomJSON
2025-10-24 14:02:48 +00:00
from tasks . spellingbee import SimpleSpelling , SpellingBee
2025-10-13 06:49:24 -07:00
# -----------------------------------------------------------------------------
# CLI arguments
parser = argparse.ArgumentParser(description="Midtrain the model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# The dtype is later resolved by comparing against 'float32' and falling back to bfloat16 for
# anything else, so restrict the accepted values here: a typo must fail loudly instead of
# silently training in bfloat16.
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"], help="float32|bfloat16")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
# Training horizon
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.2, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Evaluation
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
args = parser.parse_args()
user_config = vars(args).copy()  # snapshot of the CLI inputs, logged to wandb / checkpoint / report
# -----------------------------------------------------------------------------
# Compute init
# An explicit --device-type wins; an empty string means "pick for me".
device_type = args.device_type if args.device_type != "" else autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
# Autocast, device synchronization and peak-memory queries only apply on CUDA;
# on other devices they degrade to no-ops.
if device_type == "cuda":
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    synchronize = torch.cuda.synchronize
    get_max_memory = torch.cuda.max_memory_allocated
else:
    autocast_ctx = nullcontext()
    synchronize = lambda: None
    get_max_memory = lambda: 0

# wandb logging init: only the master process logs, and only when a real run name is given.
use_dummy_wandb = args.run == "dummy" or not master_process
if use_dummy_wandb:
    wandb_run = DummyWandb()
else:
    wandb_run = wandb.init(project="nanochat-mid", name=args.run, config=user_config)
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
orig_model = model  # uncompiled module; its state_dict is what gets checkpointed
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len  # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size  # total tokens per iteration for all ranks
# Explicit check instead of `assert`: asserts are stripped under `python -O`, and a bad value would
# then silently floor-divide to a wrong (or zero) number of gradient accumulation steps.
if args.total_batch_size <= 0 or args.total_batch_size % world_tokens_per_fwdbwd != 0:
    raise ValueError(
        f"--total-batch-size ({args.total_batch_size:,}) must be a positive multiple of "
        f"device_batch_size x max_seq_len x world_size = {world_tokens_per_fwdbwd:,} tokens"
    )
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
# Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
    group["lr"] = group["lr"] * args.init_lr_frac
    group["initial_lr"] = group["lr"]  # the per-step LR schedule rescales from this value
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 2 x 1K + 200K + 80K = 850K rows
val_dataset = TaskMixture([
    SmolTalk(split="test"), # 24K rows in test set
    MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
    GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total used: 24K + 5.2K + 0.42K ~= 29.6K rows
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these three global variables (last_step, approx_progress, current_epoch) and update them
# from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
current_epoch = 1 # track epoch for logging
def pack_row_bestfit(conv_buffer, row_capacity, pad_token, refill=None):
    """
    Build one row of exactly `row_capacity` tokens from `conv_buffer` using best-fit packing.

    Repeatedly pops the longest buffered conversation that still fits entirely into the
    remaining space (the first one wins ties) and appends it. When nothing fits, the rest of
    the row is filled with `pad_token`, so a partially filled row never has a conversation
    cropped into it.

    Exception: if the row is still empty and nothing fits, every non-empty buffered
    conversation is longer than a whole row and could never be packed as-is. The oldest such
    conversation is cropped to `row_capacity` so it is consumed instead of sitting in the
    buffer forever, which would eventually yield rows made entirely of padding.

    Args:
        conv_buffer: list of token-id lists; mutated in place (packed conversations are popped).
        row_capacity: length of the returned row.
        pad_token: token id used to pad the tail of the row.
        refill: optional zero-argument callable invoked before every pick, used to top up
            `conv_buffer`.

    Returns:
        (row, content_len, num_consumed): the row (always `row_capacity` long), the number of
        leading non-padding tokens, and how many conversations were popped from `conv_buffer`.
    """
    row = []
    num_consumed = 0
    while len(row) < row_capacity:
        if refill is not None:
            refill()
        remaining = row_capacity - len(row)
        # Find the largest conversation that fits entirely
        best_idx, best_len = -1, 0
        for i, conv in enumerate(conv_buffer):
            if best_len < len(conv) <= remaining:
                best_idx, best_len = i, len(conv)
        if best_idx >= 0:
            row.extend(conv_buffer.pop(best_idx))
            num_consumed += 1
            continue
        if row:
            break  # partially filled row: pad the remainder instead of cropping
        oversized_idx = next((i for i, conv in enumerate(conv_buffer) if len(conv) > remaining), None)
        if oversized_idx is None:
            break  # nothing usable in the buffer at all
        row.extend(conv_buffer.pop(oversized_idx)[:row_capacity])
        num_consumed += 1
    content_len = len(row)
    row.extend([pad_token] * (row_capacity - content_len))
    return row, content_len, num_consumed


def mid_data_generator_bos_bestfit(split, buffer_size=100):
    """
    BOS-aligned dataloader for midtraining with bestfit-pad packing.

    Each row in the batch starts with BOS (beginning of a conversation).
    Conversations are packed using the best-fit algorithm (`pack_row_bestfit`). When no
    conversation fits, the row is padded instead of cropped, so no tokens are discarded.
    The only exception is a conversation longer than an entire row: it is cropped, because
    it could otherwise never be consumed.
    Padding positions have targets masked with -1 (ignore_index for cross-entropy).

    Args:
        split: "train" or "val".
        buffer_size: number of rendered conversations kept in the packing buffer.

    Yields:
        (inputs, targets): `inputs` as int32 and `targets` as int64, each of shape
        (device_batch_size, max_seq_len), moved to `device`.

    For split == "train", the generator also updates the module globals `last_step`,
    `approx_progress` and `current_epoch` that the training loop reads.
    """
    global last_step, approx_progress, current_epoch
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    row_capacity = args.max_seq_len + 1  # +1 for target at last position
    bos_token = tokenizer.get_bos_token_id()
    # --num-iterations counts optimizer steps, but every step pulls grad_accum_steps micro-batches
    # from this generator. On top of that, the training loop prefetches one batch before its first
    # step, so the batch that pushes `it` past this count is only prefetched and never trained on.
    total_micro_batches = args.num_iterations * grad_accum_steps
    conv_buffer = []  # Conversation buffer: list of token lists
    cursor = ddp_rank  # Each rank processes different conversations (for fetching)
    consumed = ddp_rank  # Track actual consumption separately from buffering
    epoch = 1
    it = 0  # iteration counter (number of batches yielded)

    def refill_buffer():
        nonlocal cursor, epoch
        while len(conv_buffer) < buffer_size:
            conversation = dataset[cursor]
            ids, _ = tokenizer.render_conversation(conversation)
            conv_buffer.append(ids)
            cursor += ddp_world_size
            if cursor >= dataset_size:
                cursor = cursor % dataset_size
                epoch += 1
                # Note: last_step is triggered based on consumption, not fetching

    use_cuda = device_type == "cuda"
    while True:
        rows = []
        row_lengths = []  # actual content length (excluding padding) of each row
        for _ in range(args.device_batch_size):
            row, content_len, num_consumed = pack_row_bestfit(conv_buffer, row_capacity, bos_token, refill=refill_buffer)
            rows.append(row)
            row_lengths.append(content_len)
            consumed += num_consumed * ddp_world_size  # Track actual consumption
        it += 1
        if split == "train":
            current_epoch = epoch
            if args.num_iterations > 0:
                # Stopping condition to respect num_iterations (see total_micro_batches above)
                if it > total_micro_batches:
                    last_step = True
                approx_progress = min((it - 1) / total_micro_batches, 1.0)
            else:
                # Progress is based on consumed, not cursor, to account for buffering. It is clamped
                # because the final batch can overshoot the epoch boundary, and progress > 1 would
                # drive the LR multiplier negative.
                approx_progress = min(consumed / dataset_size, 1.0)
                # Trigger last_step when we've consumed enough (instead of when cursor wraps)
                if consumed >= dataset_size:
                    last_step = True
        # Build tensors
        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
        # Mask out padding positions in targets (set to -1 = ignore_index).
        # targets[j] is row[j + 1], so positions >= content_len - 1 predict padding. The lower bound
        # is clamped at 0 so that a row with no content is masked entirely (-1: would mask only the
        # last slot).
        for i, content_len in enumerate(row_lengths):
            if content_len < row_capacity:
                targets[i, max(content_len - 1, 0):] = -1
        yield inputs, targets
train_loader = mid_data_generator_bos_bestfit("train")


def build_val_loader():
    """Return a brand-new validation generator, so every evaluation starts from the same data."""
    return mid_data_generator_bos_bestfit("val")


progress = 0  # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
def get_lr_multiplier(progress):
    """
    Return the learning-rate multiplier for a training progress fraction.

    The multiplier is constant 1 for the first 80% of training, then ramps down linearly to 0
    at progress == 1. It is clamped at 0, so a progress value that overshoots 1 never produces
    a negative learning rate.
    """
    if progress < 0.8:
        return 1
    return max(0.0, 1 - (progress - 0.8) / 0.2)
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    """Warm Muon momentum linearly from 0.85 to 0.95 over the first 300 steps, then hold it."""
    warmup_frac = min(it / 300, 1)
    start_momentum, final_momentum = 0.85, 0.95
    return (1 - warmup_frac) * start_momentum + warmup_frac * final_momentum
# -----------------------------------------------------------------------------
# Training loop
# Each iteration: sync the stop flag across ranks, optionally evaluate / checkpoint, then run one
# optimizer step over grad_accum_steps micro-batches and log. The loop ends once the data
# generator has set `last_step`.
x, y = next(train_loader)  # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0  # EMA of training loss
ema_beta = 0.9  # EMA decay factor
total_training_time = 0  # total wall-clock time of training
step = 0
while True:
    flops_so_far = num_flops_per_token * args.total_batch_size * step

    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # save checkpoint at the end of the run (only on master process)
    if master_process and last_step and not args.dry_run:
        output_dirname = args.model_tag if args.model_tag else f"d{depth}"  # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            optimizer.state_dict(),
            {
                "step": step,
                "val_bpb": val_bpb,  # loss at last step
                "model_config": {
                    "sequence_len": args.max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                },
                "user_config": user_config,  # inputs to the training script
            }
        )

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()  # for logging
        loss = loss / grad_accum_steps  # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y = next(train_loader)  # prefetch the next batch while the GPU is busy with forward/backward
        progress = max(progress, approx_progress)  # only increase progress monotonically
    # step the optimizer
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
    optimizer.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # State
    step += 1

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()  # EMA the training loss
    # Debias the EMA (which started at 0). `step` was already incremented above, so it equals the
    # number of EMA updates so far, which is exactly the exponent the bias correction needs.
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**step)
    pct_done = 100 * progress
    tok_per_sec = int(args.total_batch_size / dt)
    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
    promised_flops_per_sec_h100 = 989e12 * ddp_world_size  # bfloat16 H100 SXM and without 2:4 sparsity
    mfu = 100 * flops_per_sec / promised_flops_per_sec_h100  # in %
    if step > 10:
        total_training_time += dt  # only count the time after the first 10 steps
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
            "train/epoch": current_epoch,
        })
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report (skipped for dry runs)
if not args.dry_run:
    from nanochat.report import get_report
    setup_stats = {  # stats about the training setup
        "Number of iterations": step,
        "DDP world size": ddp_world_size,
    }
    outcome_stats = {  # stats about training outcomes
        "Minimum validation bpb": min_val_bpb,
    }
    get_report().log(section="Midtraining", data=[user_config, setup_stats, outcome_stats])

# cleanup
wandb_run.finish()  # wandb run finish
compute_cleanup()