"""
Supervised fine-tuning (SFT) of the model.
Run as:

python -m scripts.chat_sft

Or torchrun for training:

torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
"""
2026-02-16 14:41:53 +00:00
import gc
2026-01-04 19:14:23 +00:00
import argparse
2025-10-13 06:49:24 -07:00
import os
2025-12-28 03:32:46 +00:00
os . environ [ " PYTORCH_ALLOC_CONF " ] = " expandable_segments:True "
2026-01-31 19:12:25 +00:00
import time
2025-10-13 06:49:24 -07:00
import wandb
import torch
2025-10-20 10:15:17 -07:00
from contextlib import nullcontext
2026-02-16 14:41:53 +00:00
from nanochat . common import compute_init , compute_cleanup , print0 , DummyWandb , get_base_dir , autodetect_device_type , get_peak_flops
2026-01-31 19:12:25 +00:00
from nanochat . tokenizer import get_token_bytes
2026-02-16 14:41:53 +00:00
from nanochat . checkpoint_manager import save_checkpoint , load_model , load_optimizer_state
2026-01-31 19:12:25 +00:00
from nanochat . loss_eval import evaluate_bpb
import torch . distributed as dist
2026-02-16 14:41:53 +00:00
from nanochat . flash_attention import HAS_FA3
from nanochat . engine import Engine
from scripts . chat_eval import run_chat_eval
2025-10-13 06:49:24 -07:00
2025-10-15 16:42:23 +00:00
from tasks . common import TaskMixture
2025-10-13 06:49:24 -07:00
from tasks . gsm8k import GSM8K
2026-01-31 19:12:25 +00:00
from tasks . mmlu import MMLU
2025-10-13 06:49:24 -07:00
from tasks . smoltalk import SmolTalk
2025-10-21 15:04:58 +00:00
from tasks . customjson import CustomJSON
2025-10-24 14:02:48 +00:00
from tasks . spellingbee import SimpleSpelling , SpellingBee
# -----------------------------------------------------------------------------
# CLI arguments
# Options left at None are resolved later from the pretrained checkpoint.
parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the model")
add_arg = parser.add_argument

# --- Logging ---
add_arg("--run", type=str, default="dummy",
        help="wandb run name ('dummy' disables wandb logging)")

# --- Runtime ---
add_arg("--device-type", type=str, default="",
        help="cuda|cpu|mps (empty = autodetect)")

# --- Model loading ---
add_arg("--model-tag", type=str, default=None,
        help="model tag to load from")
add_arg("--model-step", type=int, default=None,
        help="model step to load from")
add_arg("--load-optimizer", type=int, default=1,
        help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")

# --- Training horizon ---
add_arg("--num-iterations", type=int, default=-1,
        help="number of optimization steps (-1 = full epoch)")

# --- Batch sizes (default: inherit from pretrained checkpoint) ---
add_arg("--max-seq-len", type=int, default=None,
        help="max context length (default: inherit from pretrain)")
add_arg("--device-batch-size", type=int, default=None,
        help="per-device batch size (default: inherit from pretrain)")
add_arg("--total-batch-size", type=int, default=None,
        help="total batch size in tokens (default: inherit from pretrain)")

# --- Optimization (default: inherit from pretrained checkpoint) ---
add_arg("--embedding-lr", type=float, default=None,
        help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)")
add_arg("--unembedding-lr", type=float, default=None,
        help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)")
add_arg("--matrix-lr", type=float, default=None,
        help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)")
add_arg("--init-lr-frac", type=float, default=0.8,
        help="initial LR as fraction of base LR")
add_arg("--warmup-ratio", type=float, default=0.0,
        help="ratio of iterations for LR warmup")
add_arg("--warmdown-ratio", type=float, default=0.5,
        help="ratio of iterations for LR warmdown")
add_arg("--final-lr-frac", type=float, default=0.0,
        help="final LR as fraction of initial LR")

# --- Evaluation ---
add_arg("--eval-every", type=int, default=200,
        help="evaluate val bpb every N steps (-1 = disable)")
add_arg("--eval-tokens", type=int, default=40 * 524288,
        help="number of tokens to evaluate val loss on")
add_arg("--chatcore-every", type=int, default=200,
        help="evaluate ChatCORE metric every N steps (-1 = disable)")
add_arg("--chatcore-max-cat", type=int, default=-1,
        help="max problems per categorical task for ChatCORE")
add_arg("--chatcore-max-sample", type=int, default=24,
        help="max problems per generative task for ChatCORE")

# --- Data mixture ---
add_arg("--mmlu-epochs", type=int, default=3,
        help="number of epochs of MMLU in training mixture (teaches Multiple Choice)")
add_arg("--gsm8k-epochs", type=int, default=4,
        help="number of epochs of GSM8K in training mixture (teaches Math and Tool Use)")
del add_arg

args = parser.parse_args()
# Snapshot of the raw CLI values, taken before any None is replaced by an inherited value.
user_config = dict(vars(args))
# -----------------------------------------------------------------------------
# Compute init
# An empty --device-type means "autodetect".
device_type = args.device_type or autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0

# Device-specific helpers: bf16 autocast, timing sync, memory stats and peak FLOPS (for MFU).
if device_type == "cuda":
    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
    synchronize = torch.cuda.synchronize
    get_max_memory = torch.cuda.max_memory_allocated
    gpu_device_name = torch.cuda.get_device_name(0)
    gpu_peak_flops = get_peak_flops(gpu_device_name)
    print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
    autocast_ctx = nullcontext()
    synchronize = lambda: None
    get_max_memory = lambda: 0
    gpu_peak_flops = float('inf')  # MFU not meaningful for CPU/MPS

# wandb logging init: only the master process logs, and only for a named run.
use_dummy_wandb = not master_process or args.run == "dummy"
if use_dummy_wandb:
    wandb_run = DummyWandb()
else:
    wandb_run = wandb.init(project="nanochat-sft", name=args.run, config=user_config)

# Flash Attention status
if not HAS_FA3:
    print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override).
# Batch-shape settings are read from the checkpoint meta itself; learning rates from the
# pretraining run's recorded CLI config. If neither the CLI nor the checkpoint has a value,
# the hard-coded fallback is used.
pretrain_user_config = meta.get("user_config") or {}
for name, fallback, source in [
    ("max_seq_len", 2048, meta),
    ("device_batch_size", 32, meta),
    ("total_batch_size", 524288, meta),
    ("embedding_lr", 0.3, pretrain_user_config),
    ("unembedding_lr", 0.004, pretrain_user_config),
    ("matrix_lr", 0.02, pretrain_user_config),
]:
    arg_val = getattr(args, name)
    pretrain_val = source.get(name)
    if arg_val is None:
        if pretrain_val is not None:
            setattr(args, name, pretrain_val)
            print0(f"Inherited {name}={pretrain_val} from pretrained checkpoint")
        else:
            # Checkpoint did not record this value: say so instead of claiming inheritance.
            setattr(args, name, fallback)
            print0(f"Using fallback default {name}={fallback} (not found in pretrained checkpoint)")
    elif pretrain_val is not None and arg_val != pretrain_val:
        print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}")
    else:
        print0(f"Using {name}={arg_val}")

orig_model = model  # uncompiled handle: used for generation-based evals and for checkpointing
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len  # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size  # total tokens per iteration for all ranks
# Explicit check instead of `assert`: under `python -O` an assert is stripped and
# grad_accum_steps would be silently floored, changing the effective batch size.
if args.total_batch_size % world_tokens_per_fwdbwd != 0:
    raise ValueError(
        f"total_batch_size ({args.total_batch_size:,}) must be divisible by "
        f"device_batch_size * max_seq_len * world_size "
        f"({args.device_batch_size} * {args.max_seq_len} * {ddp_world_size} = {world_tokens_per_fwdbwd:,})"
    )
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
# Note: load_state_dict overwrites param_group metadata (LRs, betas, weight_decay, etc.)
# with the pretrained values. Pretraining warmdown brings LRs to ~0, and a checkpoint
# taken before the end of pretraining (e.g. via --model-step) may still carry a nonzero
# weight_decay. So we save our fresh SFT values for both and restore them after loading.
base_dir = get_base_dir()
if args.load_optimizer:
    optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
    if optimizer_data is not None:
        sft_group_hparams = [
            {key: group[key] for key in ("lr", "weight_decay") if key in group}
            for group in optimizer.param_groups
        ]
        optimizer.load_state_dict(optimizer_data)
        del optimizer_data  # free the (possibly large) loaded state early
        for group, saved in zip(optimizer.param_groups, sft_group_hparams):
            group.update(saved)
        print0("Loaded optimizer state from pretrained checkpoint (momentum buffers only, LRs and weight decay reset)")
    else:
        print0("WARNING: optimizer checkpoint not found, starting with fresh optimizer (slightly worse)")

# Override the initial learning rate as a fraction of the base learning rate.
# "initial_lr" is the reference value the training loop scales by the LR multiplier each step.
for group in optimizer.param_groups:
    group["lr"] = group["lr"] * args.init_lr_frac
    group["initial_lr"] = group["lr"]
# SFT data mixture and DataLoader
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_tasks = [
    SmolTalk(split="train"),  # 460K rows of general conversations
    CustomJSON(filepath=identity_conversations_filepath),  # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath),  # 2 epochs of these
    *[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)],  # 100K rows per epoch
    *[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)],  # 8K rows per epoch
    SimpleSpelling(size=200000, split="train"),  # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"),  # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]
train_dataset = TaskMixture(train_tasks)
print0(f"Training mixture: {len(train_dataset):,} rows (MMLU x{args.mmlu_epochs}, GSM8K x{args.gsm8k_epochs})")
val_dataset = TaskMixture([
    SmolTalk(split="test"),  # 24K rows in test set
    MMLU(subset="all", split="test", stop=5200),  # 14K rows in test set, use only 5.2K to match the train ratios
    GSM8K(subset="main", split="test", stop=420),  # 1.32K rows in test set, use only 420 to match the train ratios
])  # total used: 24K + 5.2K + 0.42K ~= 29.6K rows
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these three global variables and update them from within the (train) data generator.
last_step = False  # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0  # will go from 0 to 1 over the course of the epoch
current_epoch = 1  # track epoch for logging
def sft_data_generator_bos_bestfit(split, buffer_size=100):
    """
    BOS-aligned dataloader for SFT with bestfit-pad packing.

    Each row in the batch starts with BOS (beginning of a conversation).
    Conversations are packed using a best-fit algorithm. When no buffered
    conversation fits the remaining space, the row is padded (instead of
    cropping the next conversation), so conversations are not split across rows.

    The one exception is a conversation longer than a whole row
    (max_seq_len + 1 tokens). It can never be packed intact, so it is cropped to
    one row when buffered. Without this it would sit in the buffer forever,
    shrinking it. Once the buffer held only such conversations, every row would
    be pure padding and the epoch would never end.

    Padding positions, and positions whose loss mask is 0, get target -1
    (ignore_index for cross-entropy).

    Args:
        split: "train" or "val".
        buffer_size: number of rendered conversations kept for the best-fit search.

    Yields:
        (inputs, targets): inputs as int32 and targets as int64, both of shape
        (device_batch_size, max_seq_len), on `device`.

    Side effects (train split only): updates the module globals last_step,
    approx_progress and current_epoch, which the training loop reads.
    """
    global last_step, approx_progress, current_epoch
    assert split in {"train", "val"}, "split must be 'train' or 'val'"
    dataset = train_dataset if split == "train" else val_dataset
    dataset_size = len(dataset)
    assert dataset_size > 0
    row_capacity = args.max_seq_len + 1  # +1 for target at last position
    bos_token = tokenizer.get_bos_token_id()
    use_cuda = device_type == "cuda"
    # --num-iterations counts optimizer steps; each step consumes grad_accum_steps micro-batches.
    total_batches = args.num_iterations * grad_accum_steps if args.num_iterations > 0 else -1

    # Conversation buffer: list of (token_ids, loss_mask) tuples
    conv_buffer = []
    cursor = ddp_rank  # Each rank processes different conversations (for fetching)
    consumed = ddp_rank  # Track actual consumption separately from buffering
    epoch = 1
    it = 0  # number of batches yielded so far

    def refill_buffer():
        nonlocal cursor, epoch
        while len(conv_buffer) < buffer_size:
            conversation = dataset[cursor]
            ids, mask = tokenizer.render_conversation(conversation)
            # Crop conversations that could not fit even an empty row (see docstring).
            # NOTE(review): render_conversation may already truncate to its own limit;
            # this is then a no-op unless max_seq_len is smaller than that limit.
            conv_buffer.append((ids[:row_capacity], mask[:row_capacity]))
            cursor += ddp_world_size
            if cursor >= dataset_size:
                cursor = cursor % dataset_size
                epoch += 1
        # Note: last_step is triggered based on consumption, not fetching

    while True:
        rows = []
        mask_rows = []
        row_lengths = []  # actual content length (excluding padding) for each row
        for _ in range(args.device_batch_size):
            row = []
            mask_row = []
            content_len = row_capacity  # stays at full capacity unless the row gets padded
            while len(row) < row_capacity:
                refill_buffer()  # keep the buffer full (no-op when already full)
                remaining = row_capacity - len(row)
                # Find the largest conversation that fits entirely
                best_idx = -1
                best_len = 0
                for i, (conv, _) in enumerate(conv_buffer):
                    conv_len = len(conv)
                    if best_len < conv_len <= remaining:
                        best_idx = i
                        best_len = conv_len
                if best_idx < 0:
                    # Nothing fits: pad the remainder instead of cropping a conversation.
                    # Every buffered conversation fits an empty row, so the row is non-empty here.
                    content_len = len(row)
                    row.extend([bos_token] * remaining)  # Pad with BOS tokens
                    mask_row.extend([0] * remaining)
                    break
                conv, conv_mask = conv_buffer.pop(best_idx)
                row.extend(conv)
                mask_row.extend(conv_mask)
                consumed += ddp_world_size  # Track actual consumption
            row_lengths.append(content_len)
            rows.append(row[:row_capacity])
            mask_rows.append(mask_row[:row_capacity])

        it += 1
        # Update progress tracking (based on consumed, not cursor, to account for buffering)
        if split == "train":
            current_epoch = epoch
            if total_batches > 0:
                # The training loop prefetches one batch ahead. Setting last_step once more
                # than num_iterations * grad_accum_steps batches have been produced makes it
                # stop after exactly num_iterations optimizer steps.
                if it > total_batches:
                    last_step = True
                approx_progress = min(it / total_batches, 1.0)
            else:
                approx_progress = min(consumed / dataset_size, 1.0)
            # Trigger last_step when we've consumed an epoch (instead of when cursor wraps)
            if consumed >= dataset_size:
                last_step = True

        # Build tensors
        batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
        inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
        targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
        # Apply the loss mask from render_conversation. Per the renderer's contract this is
        # presumably 1 for assistant completions and 0 elsewhere. mask[1:] aligns with targets
        # (shifted by 1); masked-out positions get -1 (ignore_index).
        mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
        mask_targets = mask_tensor[:, 1:].to(device=device)
        targets[mask_targets == 0] = -1
        # Mask out padding positions in targets (set to -1 = ignore_index).
        # For each padded row, positions >= (content_length - 1) in targets are padding.
        for i, content_len in enumerate(row_lengths):
            if content_len < row_capacity:
                targets[i, content_len - 1:] = -1
        yield inputs, targets
train_loader = sft_data_generator_bos_bestfit("train")


def build_val_loader():
    """Return a fresh validation generator, starting again from this rank's first val conversation."""
    return sft_data_generator_bos_bestfit("val")


# Training progress fed to the LR schedule: will go from 0 to 1 over the course of the epoch.
progress = 0
# Learning rate schedule (linear warmup, constant, linear warmdown)
# Same shape as base_train but uses progress (0→1) instead of absolute step counts,
# because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
def get_lr_multiplier(progress):
    """
    Return the multiplier applied to each param group's initial LR.

    Args:
        progress: fraction of training completed.

    `progress` is clamped to [0, 1]. The data loader's progress estimate can
    overshoot 1.0, for example when batches are fetched during gradient accumulation
    after the end of the dataset. Without the clamp, the warmdown branch would
    extrapolate past final_lr_frac to a negative learning rate, and a
    warmdown_ratio of 0 would divide by zero.
    """
    progress = min(max(progress, 0.0), 1.0)
    if progress < args.warmup_ratio:
        return (progress + 1e-8) / args.warmup_ratio
    if progress <= 1.0 - args.warmdown_ratio:
        return 1.0
    decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
    return (1 - decay) * 1.0 + decay * args.final_lr_frac
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    """Ramp Muon momentum linearly from 0.85 at step 0 to 0.95 at step 300, then hold it there."""
    ramp = min(it / 300, 1)
    return (1 - ramp) * 0.85 + ramp * 0.95
# -----------------------------------------------------------------------------
# Training loop
step = 0
total_training_time = 0  # total wall-clock time of training
ema_beta = 0.9  # EMA decay factor for the logged training loss
smooth_train_loss = 0  # EMA of training loss (starts at 0, so it needs debiasing)
min_val_bpb = float("inf")
x, y = next(train_loader)  # prefetch the very first batch of data
while True:
    flops_so_far = num_flops_per_token * args.total_batch_size * step
    # Synchronize last_step across all ranks to avoid hangs in the distributed setting
    if ddp:
        last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
        dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
        last_step = bool(last_step_tensor.item())

    # once in a while: evaluate the val bpb (all ranks participate)
    if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "val/bpb": val_bpb,
        })
        model.train()

    # once in a while: estimate the ChatCORE metric (all ranks participate)
    # use the original uncompiled model because the inputs keep changing shape
    if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)):
        model.eval()
        engine = Engine(orig_model, tokenizer)
        all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
        categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'}
        baseline_accuracies = {
            'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25,
            'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0,
        }
        task_results = {}
        for task_name in all_tasks:
            limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
            max_problems = None if limit < 0 else limit  # -1 means no limit
            with autocast_ctx:
                acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
                                    batch_size=args.device_batch_size, max_problems=max_problems)
            task_results[task_name] = acc
            print0(f"{task_name}: {100 * acc:.2f}%")

        # Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
        def centered_mean(tasks):
            return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks)

        chatcore = centered_mean(all_tasks)
        chatcore_cat = centered_mean(categorical_tasks)
        print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "chatcore_metric": chatcore,
            "chatcore_cat": chatcore_cat,
            **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
        })
        model.train()

    # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
    if last_step:
        output_dirname = args.model_tag if args.model_tag else f"d{depth}"  # e.g. d12
        checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
        save_checkpoint(
            checkpoint_dir,
            step,
            orig_model.state_dict(),
            optimizer.state_dict(),
            {
                "step": step,
                "val_bpb": val_bpb,  # loss at last step (val eval always runs when last_step)
                "model_config": {
                    "sequence_len": args.max_seq_len,
                    "vocab_size": tokenizer.get_vocab_size(),
                    "n_layer": depth,
                    "n_head": model.config.n_head,
                    "n_kv_head": model.config.n_kv_head,
                    "n_embd": model.config.n_embd,
                    "window_pattern": model.config.window_pattern,
                },
                "user_config": user_config,  # inputs to the training script
            },
            rank=ddp_rank,
        )

    if last_step:
        break

    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        train_loss = loss.detach()  # for logging
        loss = loss / grad_accum_steps  # each .backward() is a grad sum => normalize loss here
        loss.backward()
        x, y = next(train_loader)  # prefetch the next batch while the GPU is busy with forward/backward
        progress = max(progress, approx_progress)  # only increase progress monotonically
    # step the optimizer
    lrm = get_lr_multiplier(progress)
    muon_momentum = get_muon_momentum(step)
    for group in optimizer.param_groups:
        group["lr"] = group["initial_lr"] * lrm
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
    optimizer.step()
    model.zero_grad(set_to_none=True)
    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------

    # State
    step += 1

    # logging
    smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()  # EMA the training loss
    # `step` was already incremented above, so it equals the number of EMA updates so far.
    # The correct bias correction is therefore 1 - beta**step, not beta**(step + 1).
    debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**step)  # debias the EMA
    pct_done = 100 * progress
    tok_per_sec = int(args.total_batch_size / dt)
    flops_per_sec = num_flops_per_token * args.total_batch_size / dt
    mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
    if step > 10:
        total_training_time += dt  # only count the time after the first 10 steps
    print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time / 60:.2f}m")
    if step % 10 == 0:
        wandb_run.log({
            "step": step,
            "total_training_flops": flops_so_far,
            "total_training_time": total_training_time,
            "train/loss": debiased_smooth_loss,
            "train/lrm": lrm,
            "train/dt": dt,
            "train/tok_per_sec": tok_per_sec,
            "train/mfu": mfu,
            "train/epoch": current_epoch,
        })

    # The garbage collector spends ~500ms scanning for cycles quite frequently.
    # We manually manage it to avoid these pauses during training.
    if step == 1:
        gc.collect()  # manually collect a lot of garbage from setup
        gc.freeze()  # freeze all currently surviving objects and exclude them from GC
        gc.disable()  # disable GC entirely except:
    elif step % 5000 == 0:  # every 5000 steps...
        gc.collect()  # manually collect, just to be safe for very long runs
# print a few more stats
peak_mem_mib = get_max_memory() / 1024 / 1024
print0(f"Peak memory usage: {peak_mem_mib:.2f}MiB")
print0(f"Total training time: {total_training_time / 60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

# Log to report: CLI args, training-setup stats, then training outcomes.
from nanochat.report import get_report
report_setup_stats = {
    "Number of iterations": step,
    "DDP world size": ddp_world_size,
}
report_outcome_stats = {
    "Minimum validation bpb": min_val_bpb,
}
get_report().log(section="SFT", data=[user_config, report_setup_stats, report_outcome_stats])

# cleanup
wandb_run.finish()  # wandb run finish
compute_cleanup()