2025-10-13 06:49:24 -07:00
"""
2026-01-07 22:11:52 +00:00
Train model. From root directory of the project, run as:
2025-10-13 06:49:24 -07:00
2026-01-25 18:59:51 +00:00
python -m scripts.base_train
2025-10-13 06:49:24 -07:00
or distributed as:
2026-01-25 18:59:51 +00:00
torchrun --nproc_per_node=8 -m scripts.base_train
2025-10-16 16:14:38 +00:00
2025-10-16 15:46:18 -07:00
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
2026-01-13 22:45:27 +00:00
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
2025-10-13 06:49:24 -07:00
"""
2026-02-02 01:44:30 +00:00
import gc
2025-10-13 06:49:24 -07:00
import os
2025-12-28 03:32:46 +00:00
os . environ [ " PYTORCH_ALLOC_CONF " ] = " expandable_segments:True "
2026-01-04 19:14:23 +00:00
import argparse
2025-10-13 06:49:24 -07:00
import time
2025-10-16 15:46:18 -07:00
from contextlib import nullcontext
2025-10-13 06:49:24 -07:00
import wandb
import torch
from nanochat . gpt import GPT , GPTConfig
2026-01-13 20:05:47 +00:00
from nanochat . dataloader import tokenizing_distributed_data_loader_bos_bestfit , tokenizing_distributed_data_loader_with_state_bos_bestfit
2026-01-17 03:16:12 +00:00
from nanochat . common import compute_init , compute_cleanup , print0 , DummyWandb , print_banner , get_base_dir , autodetect_device_type , get_peak_flops
2025-10-13 06:49:24 -07:00
from nanochat . tokenizer import get_tokenizer , get_token_bytes
2025-11-13 15:34:40 +00:00
from nanochat . checkpoint_manager import save_checkpoint , load_checkpoint
2025-10-13 06:49:24 -07:00
from nanochat . loss_eval import evaluate_bpb
from nanochat . engine import Engine
2026-01-16 17:37:51 +00:00
from nanochat . flash_attention import HAS_FA3
2026-02-01 05:03:44 +00:00
from scripts . base_eval import evaluate_core
2025-10-13 06:49:24 -07:00
print_banner ( )
# -----------------------------------------------------------------------------
2026-01-04 19:14:23 +00:00
# CLI arguments
parser = argparse . ArgumentParser ( description = " Pretrain base model " )
# Logging
parser . add_argument ( " --run " , type = str , default = " dummy " , help = " wandb run name ( ' dummy ' disables wandb logging) " )
2025-10-16 16:14:38 +00:00
# Runtime
2026-01-13 22:45:27 +00:00
parser . add_argument ( " --device-type " , type = str , default = " " , help = " cuda|cpu|mps (empty = autodetect) " )
2025-10-13 06:49:24 -07:00
# Model architecture
2026-01-04 19:14:23 +00:00
parser . add_argument ( " --depth " , type = int , default = 20 , help = " depth of the Transformer model " )
2026-01-13 22:45:27 +00:00
parser . add_argument ( " --aspect-ratio " , type = int , default = 64 , help = " model_dim = depth * aspect_ratio " )
parser . add_argument ( " --head-dim " , type = int , default = 128 , help = " target head dimension for attention " )
parser . add_argument ( " --max-seq-len " , type = int , default = 2048 , help = " max context length " )
parser . add_argument ( " --window-pattern " , type = str , default = " SSSL " , help = " sliding window pattern tiled across layers: L=full, S=half context (e.g. ' SSL ' ) " )
2026-01-04 19:14:23 +00:00
# Training horizon (only one used, in order of precedence)
2026-01-13 22:45:27 +00:00
parser . add_argument ( " --num-iterations " , type = int , default = - 1 , help = " explicit number of optimization steps (-1 = disable) " )
parser . add_argument ( " --target-flops " , type = float , default = - 1.0 , help = " calculate num_iterations to reach target_flops (-1 = disable) " )
2026-01-27 22:31:17 +00:00
parser . add_argument ( " --target-param-data-ratio " , type = float , default = 10.5 , help = " calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable) " )
2025-10-13 06:49:24 -07:00
# Optimization
2026-01-13 22:45:27 +00:00
parser . add_argument ( " --device-batch-size " , type = int , default = 32 , help = " per-device batch size " )
parser . add_argument ( " --total-batch-size " , type = int , default = 524288 , help = " total batch size in tokens " )
parser . add_argument ( " --embedding-lr " , type = float , default = 0.3 , help = " learning rate for embedding parameters (Adam) " )
parser . add_argument ( " --unembedding-lr " , type = float , default = 0.004 , help = " learning rate for unembedding parameters (Adam) " )
parser . add_argument ( " --weight-decay " , type = float , default = 0.2 , help = " cautious weight decay for the Muon optimizer (for weights) " )
parser . add_argument ( " --matrix-lr " , type = float , default = 0.02 , help = " learning rate for matrix parameters (Muon) " )
parser . add_argument ( " --scalar-lr " , type = float , default = 0.5 , help = " learning rate for scalars (resid_lambdas, x0_lambdas) " )
parser . add_argument ( " --adam-beta1 " , type = float , default = 0.8 , help = " Adam beta1 for embedding/unembedding " )
parser . add_argument ( " --adam-beta2 " , type = float , default = 0.95 , help = " Adam beta2 for embedding/unembedding " )
parser . add_argument ( " --warmup-ratio " , type = float , default = 0.0 , help = " ratio of iterations for LR warmup " )
2026-01-31 01:08:44 +00:00
parser . add_argument ( " --warmdown-ratio " , type = float , default = 0.5 , help = " ratio of iterations for LR warmdown " )
2026-01-13 22:45:27 +00:00
parser . add_argument ( " --final-lr-frac " , type = float , default = 0.0 , help = " final LR as fraction of initial LR " )
parser . add_argument ( " --resume-from-step " , type = int , default = - 1 , help = " resume training from this step (-1 = disable) " )
2025-10-13 06:49:24 -07:00
# Evaluation
2026-01-13 22:45:27 +00:00
parser . add_argument ( " --eval-every " , type = int , default = 250 , help = " evaluate val bpb every N steps (-1 = disable) " )
parser . add_argument ( " --eval-tokens " , type = int , default = 20 * 524288 , help = " number of tokens to evaluate val loss on " )
parser . add_argument ( " --core-metric-every " , type = int , default = 2000 , help = " evaluate CORE metric every N steps (-1 = disable) " )
parser . add_argument ( " --core-metric-max-per-task " , type = int , default = 500 , help = " examples per task for CORE metric " )
parser . add_argument ( " --sample-every " , type = int , default = 2000 , help = " sample from model every N steps (-1 = disable) " )
parser . add_argument ( " --save-every " , type = int , default = - 1 , help = " save checkpoints every N steps (-1 = only at end) " )
2025-10-13 06:49:24 -07:00
# Output
2026-01-13 22:45:27 +00:00
parser . add_argument ( " --model-tag " , type = str , default = None , help = " override model tag for checkpoint directory name " )
2026-01-04 19:14:23 +00:00
args = parser . parse_args ( )
user_config = vars ( args ) . copy ( ) # for logging
2025-10-13 06:49:24 -07:00
# -----------------------------------------------------------------------------
# Compute init
2026-01-04 19:14:23 +00:00
device_type = autodetect_device_type ( ) if args . device_type == " " else args . device_type
2025-10-16 16:14:38 +00:00
ddp , ddp_rank , ddp_local_rank , ddp_world_size , device = compute_init ( device_type )
2025-10-13 06:49:24 -07:00
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
2025-10-16 15:46:18 -07:00
autocast_ctx = torch . amp . autocast ( device_type = device_type , dtype = torch . bfloat16 ) if device_type == " cuda " else nullcontext ( )
2025-10-16 16:14:38 +00:00
synchronize = torch . cuda . synchronize if device_type == " cuda " else lambda : None
get_max_memory = torch . cuda . max_memory_allocated if device_type == " cuda " else lambda : 0
2026-01-17 03:16:12 +00:00
if device_type == " cuda " :
gpu_device_name = torch . cuda . get_device_name ( 0 )
gpu_peak_flops = get_peak_flops ( gpu_device_name )
print0 ( f " GPU: { gpu_device_name } | Peak FLOPS (BF16): { gpu_peak_flops : .2e } " )
else :
gpu_peak_flops = float ( ' inf ' ) # MFU not meaningful for CPU/MPS
2025-10-13 06:49:24 -07:00
# wandb logging init
2026-01-04 19:14:23 +00:00
use_dummy_wandb = args . run == " dummy " or not master_process
wandb_run = DummyWandb ( ) if use_dummy_wandb else wandb . init ( project = " nanochat " , name = args . run , config = user_config )
2025-10-13 06:49:24 -07:00
2026-01-16 17:37:51 +00:00
# Flash Attention status
if HAS_FA3 :
print0 ( " ✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome. " )
else :
print0 ( " ! " * 80 )
print0 ( " WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback " )
print0 ( " WARNING: Training will be less efficient without FA3 " )
if args . window_pattern != " L " :
print0 ( f " WARNING: SDPA has no support for sliding window attention (window_pattern= ' { args . window_pattern } ' ). Your GPU utilization will be terrible. " )
print0 ( " WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns. " )
print0 ( " ! " * 80 )
2025-10-13 06:49:24 -07:00
# Tokenizer will be useful for evaluation, also we need the vocab size
tokenizer = get_tokenizer ( )
token_bytes = get_token_bytes ( device = device )
vocab_size = tokenizer . get_vocab_size ( )
print0 ( f " Vocab size: { vocab_size : , } " )
# Model shape is derived from the requested depth of the model.
# The width is rounded up to the next multiple of the target head dimension so it splits
# cleanly into heads (FA3 requires head_dim divisible by 8, and this guarantees
# head_dim == args.head_dim exactly).
# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
num_layers = args.depth
base_dim = args.depth * args.aspect_ratio
# Ceil-divide to get the head count first, then expand back to the (nudged) width.
num_heads = (base_dim + args.head_dim - 1) // args.head_dim
model_dim = num_heads * args.head_dim
head_dim = model_dim // num_heads
# default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
num_kv_heads = num_heads
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
print0(f"num_heads: {num_heads}")
print0(f"head_dim: {head_dim}")
print0(f"num_kv_heads: {num_kv_heads}")
# Optimizer / data / training length related hyperparameters
# Figure out the gradient accumulation needed to reach the desired total batch size.
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len  # tokens per micro-batch for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size  # tokens per micro-batch summed over all ranks
# Validate explicitly rather than with `assert`: assertions are stripped under `python -O`,
# and a non-divisible (or non-positive) batch size would then silently yield a wrong or
# zero number of gradient accumulation steps.
if (
    world_tokens_per_fwdbwd <= 0
    or args.total_batch_size <= 0
    or args.total_batch_size % world_tokens_per_fwdbwd != 0
):
    raise ValueError(
        f"--total-batch-size ({args.total_batch_size:,}) must be a positive multiple of "
        f"device_batch_size * max_seq_len * world_size ({world_tokens_per_fwdbwd:,})"
    )
grad_accum_steps = args.total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
2025-11-13 15:34:40 +00:00
2026-01-07 22:11:52 +00:00
# Batch size scaling for learning rates (hyperparameters were tuned at reference batch size 2^19)
reference_batch_size = 2 ** 19
batch_ratio = args.total_batch_size / reference_batch_size
if batch_ratio == 1.0:
    # Running at the reference batch size: learning rates are used as given.
    batch_lr_scale = 1.0
else:
    # SGD: linear scaling with batch size is standard (not used in nanochat)
    # AdamW: sqrt scaling is standard
    # Muon: sqrt scaling is an assumption - not fully studied, but it's a second-order-ish optimizer
    batch_lr_scale = batch_ratio ** 0.5
    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {args.total_batch_size:,} (reference: {reference_batch_size:,})")
# Weight decay is tuned at d12 and its scaling seems to be \propto 1/channels^2
# (or equivalently, \propto 1/depth^2 due to constant aspect ratio)
weight_decay_scaled = args.weight_decay * (12 / args.depth) ** 2
if args.depth != 12:
    print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
2025-10-13 06:49:24 -07:00
# -----------------------------------------------------------------------------
# Initialize the Model
2025-11-13 15:34:40 +00:00
# Create a new model with random weights
2026-01-11 21:49:54 +00:00
model_config_kwargs = dict ( sequence_len = args . max_seq_len , vocab_size = vocab_size , n_layer = num_layers , n_head = num_heads , n_kv_head = num_kv_heads , n_embd = model_dim , window_pattern = args . window_pattern )
2025-10-13 06:49:24 -07:00
with torch . device ( " meta " ) :
2026-01-01 21:14:26 +00:00
# All tensors are created as meta tensors (they have shape/dtype but no data)
2025-10-13 06:49:24 -07:00
model_config = GPTConfig ( * * model_config_kwargs )
model = GPT ( model_config )
2026-01-01 21:14:26 +00:00
model . to_empty ( device = device ) # All tensors get storage on target device but with uninitialized (garbage) data
model . init_weights ( ) # All tensors get initialized
2025-11-13 15:34:40 +00:00
# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir ( )
2026-01-04 19:14:23 +00:00
output_dirname = args . model_tag if args . model_tag else f " d { args . depth } " # e.g. d12
2025-11-13 15:34:40 +00:00
checkpoint_dir = os . path . join ( base_dir , " base_checkpoints " , output_dirname )
2026-01-04 19:14:23 +00:00
resuming = args . resume_from_step != - 1
2025-11-13 15:34:40 +00:00
if resuming :
2026-01-04 19:14:23 +00:00
print0 ( f " Resuming optimization from step { args . resume_from_step } " )
model_data , optimizer_data , meta_data = load_checkpoint ( checkpoint_dir , args . resume_from_step , device , load_optimizer = True , rank = ddp_rank )
2025-11-13 15:34:40 +00:00
model . load_state_dict ( model_data , strict = True , assign = True )
del model_data # free up this memory after the copy
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
model = torch . compile ( model , dynamic = False ) # the inputs to model will never change shape so dynamic=False is safe
2026-01-27 22:31:17 +00:00
# Detailed parameter counts
param_counts = orig_model . num_scaling_params ( )
print0 ( f " Parameter counts: " )
for key , value in param_counts . items ( ) :
print0 ( f " { key : 24s } : { value : , } " )
num_params = param_counts [ ' total ' ]
num_scaling_params = param_counts [ ' transformer_matrices ' ] + param_counts [ ' lm_head ' ] # determined to give the cleanest scaling laws, see dev/LOG.md Jan 27, 2026
2025-10-13 06:49:24 -07:00
num_flops_per_token = model . estimate_flops ( )
print0 ( f " Estimated FLOPs per token: { num_flops_per_token : e } " )
# Calculate number of iterations. Either it is given, or from target flops, or from
# target data:param ratio (in that order of precedence).
if args.num_iterations > 0:
    num_iterations = args.num_iterations
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif args.target_flops > 0:
    # calculate the number of iterations from the target flops
    num_iterations = round(args.target_flops / (num_flops_per_token * args.total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
    # calculate the number of iterations from the target param data ratio (use scaling params per Kaplan et al.)
    target_tokens = int(args.target_param_data_ratio * num_scaling_params)
    num_iterations = target_tokens // args.total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    raise ValueError(
        "No training horizon specified: set one of --num-iterations, --target-flops "
        "or --target-param-data-ratio to a positive value"
    )
# A very small FLOPs / ratio target can round down to zero optimizer steps. Fail here with a
# clear message instead of running the eval/save path and crashing later on values that are
# only produced by a real training step.
if num_iterations < 1:
    raise ValueError(
        f"Training horizon resolves to {num_iterations} iterations; "
        "increase the target or pass --num-iterations explicitly"
    )
total_tokens = args.total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Scaling params ratio: {total_tokens / num_scaling_params:.2f}")  # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
# -----------------------------------------------------------------------------
2026-01-29 00:50:50 +00:00
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
2026-01-07 22:11:52 +00:00
adam_betas = ( args . adam_beta1 , args . adam_beta2 )
2026-01-29 00:50:50 +00:00
optimizer = model . setup_optimizer (
2026-01-07 22:11:52 +00:00
unembedding_lr = args . unembedding_lr * batch_lr_scale ,
embedding_lr = args . embedding_lr * batch_lr_scale ,
matrix_lr = args . matrix_lr * batch_lr_scale ,
2026-01-11 16:56:59 +00:00
weight_decay = weight_decay_scaled ,
2026-01-07 22:11:52 +00:00
adam_betas = adam_betas ,
2026-01-11 18:47:35 +00:00
scalar_lr = args . scalar_lr * batch_lr_scale ,
2026-01-07 22:11:52 +00:00
)
2025-10-13 06:49:24 -07:00
2025-11-13 15:34:40 +00:00
if resuming :
2026-01-29 00:50:50 +00:00
optimizer . load_state_dict ( optimizer_data )
del optimizer_data
2025-11-13 15:34:40 +00:00
# -----------------------------------------------------------------------------
2025-10-13 06:49:24 -07:00
# Initialize the DataLoaders for train/val
2025-11-13 15:34:40 +00:00
dataloader_resume_state_dict = None if not resuming else meta_data [ " dataloader_state_dict " ]
2026-01-13 20:05:47 +00:00
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit ( tokenizer , args . device_batch_size , args . max_seq_len , split = " train " , device = device , resume_state_dict = dataloader_resume_state_dict )
build_val_loader = lambda : tokenizing_distributed_data_loader_bos_bestfit ( tokenizer , args . device_batch_size , args . max_seq_len , split = " val " , device = device )
2025-11-13 15:34:40 +00:00
x , y , dataloader_state_dict = next ( train_loader ) # kick off load of the very first batch of data
2025-10-13 06:49:24 -07:00
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
# Learning rate scheduler
def get_lr_multiplier(it):
    """Return the learning-rate multiplier for optimization step `it`.

    The schedule has three phases, sized from the module-level `args` ratios and
    `num_iterations`: a linear warmup, a flat plateau at 1.0, and a linear warmdown
    that ends at `args.final_lr_frac`.
    """
    n_warmup = round(args.warmup_ratio * num_iterations)
    n_warmdown = round(args.warmdown_ratio * num_iterations)
    # Warmup: ramps linearly and reaches 1.0 on the last warmup step.
    if it < n_warmup:
        return (it + 1) / n_warmup
    # Plateau: everything up to the start of the warmdown window.
    if it <= num_iterations - n_warmdown:
        return 1.0
    # Warmdown: interpolate from 1.0 towards final_lr_frac as the remaining steps go to 0.
    remaining = (num_iterations - it) / n_warmdown
    return remaining * 1.0 + (1 - remaining) * args.final_lr_frac
2025-10-13 06:49:24 -07:00
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
    """Return the Muon momentum for step `it`.

    Linearly ramps from 0.85 to 0.95 over the first 300 steps and stays at 0.95 afterwards.
    """
    start, end = 0.85, 0.95
    ramp = min(it / 300, 1)  # fraction of the ramp completed, capped at 1
    return (1 - ramp) * start + ramp * end
2026-01-11 16:56:59 +00:00
# Weight decay scheduler for Muon optimizer (linear to zero over the course of training)
def get_weight_decay(it):
    """Return the weight decay for step `it`.

    Decays linearly from the module-level `weight_decay_scaled` at step 0 down to
    zero at step `num_iterations`.
    """
    remaining_frac = 1 - it / num_iterations
    return weight_decay_scaled * remaining_frac
2025-11-13 15:34:40 +00:00
# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)
# `mfu` is only assigned inside a training step, but the final report formats it
# unconditionally. Seed it here so that a run which executes zero optimizer steps
# (e.g. resuming from the final step just to re-run eval/save) does not raise NameError.
mfu = 0.0
if resuming:
    # Restore the counters that were stored in the checkpoint metadata.
    loop_state = meta_data["loop_state"]
    step = meta_data["step"]
    val_bpb = meta_data["val_bpb"]
    min_val_bpb = loop_state["min_val_bpb"]
    smooth_train_loss = loop_state["smooth_train_loss"]
    total_training_time = loop_state["total_training_time"]
else:
    step = 0
    val_bpb = None  # will be set if eval_every > 0
    min_val_bpb = float("inf")
    smooth_train_loss = 0  # EMA of training loss
    total_training_time = 0  # total wall-clock time of training
2025-10-13 06:49:24 -07:00
# -----------------------------------------------------------------------------
# Training loop
2025-11-13 15:34:40 +00:00
while True :
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
2026-01-04 19:14:23 +00:00
flops_so_far = num_flops_per_token * args . total_batch_size * step
2025-10-13 06:49:24 -07:00
# once in a while: evaluate the val bpb (all ranks participate)
2026-01-05 00:38:09 +00:00
if args . eval_every > 0 and ( last_step or step % args . eval_every == 0 ) :
2025-10-13 06:49:24 -07:00
model . eval ( )
val_loader = build_val_loader ( )
2026-01-04 19:14:23 +00:00
eval_steps = args . eval_tokens / / ( args . device_batch_size * args . max_seq_len * ddp_world_size )
2025-10-13 06:49:24 -07:00
with autocast_ctx :
val_bpb = evaluate_bpb ( model , val_loader , eval_steps , token_bytes )
2026-01-11 16:56:59 +00:00
print0 ( f " Step { step : 05d } | Validation bpb: { val_bpb : .6f } " )
2025-10-13 06:49:24 -07:00
if val_bpb < min_val_bpb :
min_val_bpb = val_bpb
wandb_run . log ( {
" step " : step ,
" total_training_flops " : flops_so_far ,
" total_training_time " : total_training_time ,
" val/bpb " : val_bpb ,
} )
model . train ( )
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
2025-10-16 15:46:18 -07:00
results = { }
2026-01-04 19:14:23 +00:00
if args . core_metric_every > 0 and ( last_step or ( step > 0 and step % args . core_metric_every == 0 ) ) :
2025-10-13 06:49:24 -07:00
model . eval ( )
with autocast_ctx :
2026-02-01 05:03:44 +00:00
results = evaluate_core ( orig_model , tokenizer , device , max_per_task = args . core_metric_max_per_task )
2025-10-13 06:49:24 -07:00
print0 ( f " Step { step : 05d } | CORE metric: { results [ ' core_metric ' ] : .4f } " )
wandb_run . log ( {
" step " : step ,
" total_training_flops " : flops_so_far ,
" core_metric " : results [ " core_metric " ] ,
" centered_results " : results [ " centered_results " ] ,
} )
model . train ( )
# once in a while: sample from the model (only on master process)
# use the original uncompiled model because the inputs keep changing shape
2026-01-05 00:38:09 +00:00
if args . sample_every > 0 and master_process and ( last_step or ( step > 0 and step % args . sample_every == 0 ) ) :
2025-10-13 06:49:24 -07:00
model . eval ( )
prompts = [
" The capital of France is " ,
" The chemical symbol of gold is " ,
" If yesterday was Friday, then tomorrow will be " ,
" The opposite of hot is " ,
" The planets of the solar system are: " ,
" My favorite color is " ,
" If 5*x + 3 = 13, then x is " ,
]
2025-10-20 00:05:09 +00:00
engine = Engine ( orig_model , tokenizer ) # use orig_model to avoid recompilation
2025-10-13 06:49:24 -07:00
for prompt in prompts :
tokens = tokenizer ( prompt , prepend = " <|bos|> " )
with autocast_ctx :
sample , _ = engine . generate_batch ( tokens , num_samples = 1 , max_tokens = 16 , temperature = 0 )
print0 ( tokenizer . decode ( sample [ 0 ] ) )
model . train ( )
2025-11-13 15:34:40 +00:00
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
2026-01-04 19:14:23 +00:00
if last_step or ( step > 0 and step != args . resume_from_step and args . save_every > 0 and step % args . save_every == 0 ) :
2025-10-13 06:49:24 -07:00
save_checkpoint (
checkpoint_dir ,
step ,
2025-11-13 15:34:40 +00:00
orig_model . state_dict ( ) , # model parameters
2026-01-29 00:50:50 +00:00
optimizer . state_dict ( ) , # optimizer state
2025-11-13 15:34:40 +00:00
{ # metadata saved as json
2025-10-13 06:49:24 -07:00
" step " : step ,
" val_bpb " : val_bpb , # loss at last step
" model_config " : model_config_kwargs ,
" user_config " : user_config , # inputs to the training script
2026-01-04 19:14:23 +00:00
" device_batch_size " : args . device_batch_size ,
" max_seq_len " : args . max_seq_len ,
2025-11-13 15:34:40 +00:00
" dataloader_state_dict " : dataloader_state_dict ,
" loop_state " : { # all loop state (other than step) so that we can resume training
" min_val_bpb " : min_val_bpb ,
" smooth_train_loss " : smooth_train_loss ,
" total_training_time " : total_training_time ,
} ,
} ,
rank = ddp_rank ,
2025-10-13 06:49:24 -07:00
)
2025-11-13 15:34:40 +00:00
# termination conditions (TODO: possibly also add loss explosions etc.)
2025-10-13 06:49:24 -07:00
if last_step :
break
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
2025-10-16 16:14:38 +00:00
synchronize ( )
2025-10-13 06:49:24 -07:00
t0 = time . time ( )
for micro_step in range ( grad_accum_steps ) :
with autocast_ctx :
loss = model ( x , y )
train_loss = loss . detach ( ) # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss . backward ( )
2025-11-13 15:34:40 +00:00
x , y , dataloader_state_dict = next ( train_loader ) # prefetch the next batch while the GPU is busy with forward/backward
2026-01-29 00:50:50 +00:00
# step the optimizer
2025-10-13 06:49:24 -07:00
lrm = get_lr_multiplier ( step )
muon_momentum = get_muon_momentum ( step )
2026-01-11 16:56:59 +00:00
muon_weight_decay = get_weight_decay ( step )
2026-01-29 00:50:50 +00:00
for group in optimizer . param_groups :
group [ " lr " ] = group [ " initial_lr " ] * lrm
if group [ ' kind ' ] == ' muon ' :
group [ " momentum " ] = muon_momentum
group [ " weight_decay " ] = muon_weight_decay
optimizer . step ( )
2025-10-13 06:49:24 -07:00
model . zero_grad ( set_to_none = True )
2026-01-15 23:30:11 +00:00
train_loss_f = train_loss . item ( ) # .item() is a CPU-GPU sync point
2025-10-16 16:14:38 +00:00
synchronize ( )
2025-10-13 06:49:24 -07:00
t1 = time . time ( )
dt = t1 - t0
# -------------------------------------------------------------------------
2026-01-15 23:30:11 +00:00
# logging (CPU action only)
2025-11-13 15:34:40 +00:00
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
2026-01-15 23:30:11 +00:00
smooth_train_loss = ema_beta * smooth_train_loss + ( 1 - ema_beta ) * train_loss_f # EMA the training loss
2025-10-13 06:49:24 -07:00
debiased_smooth_loss = smooth_train_loss / ( 1 - ema_beta * * ( step + 1 ) ) # debias the EMA
pct_done = 100 * step / num_iterations
2026-01-04 19:14:23 +00:00
tok_per_sec = int ( args . total_batch_size / dt )
flops_per_sec = num_flops_per_token * args . total_batch_size / dt
2026-01-17 03:16:12 +00:00
mfu = 100 * flops_per_sec / ( gpu_peak_flops * ddp_world_size )
2025-10-13 06:49:24 -07:00
if step > 10 :
total_training_time + = dt # only count the time after the first 10 steps
2026-01-05 00:38:09 +00:00
# Calculate ETA based on average time per step (excluding first 10 steps)
steps_done = step - 10
if steps_done > 0 :
avg_time_per_step = total_training_time / steps_done
remaining_steps = num_iterations - step
eta_seconds = remaining_steps * avg_time_per_step
eta_str = f " | eta: { eta_seconds / 60 : .1f } m "
else :
eta_str = " "
2026-01-13 20:05:47 +00:00
epoch = dataloader_state_dict [ " epoch " ]
print0 ( f " step { step : 05d } / { num_iterations : 05d } ( { pct_done : .2f } %) | loss: { debiased_smooth_loss : .6f } | lrm: { lrm : .2f } | dt: { dt * 1000 : .2f } ms | tok/sec: { tok_per_sec : , } | mfu: { mfu : .2f } | epoch: { epoch } | total time: { total_training_time / 60 : .2f } m { eta_str } " )
2025-10-13 06:49:24 -07:00
if step % 100 == 0 :
2025-11-05 21:08:30 +00:00
log_data = {
2025-10-13 06:49:24 -07:00
" step " : step ,
" total_training_flops " : flops_so_far ,
" total_training_time " : total_training_time ,
" train/loss " : debiased_smooth_loss ,
" train/lrm " : lrm ,
" train/dt " : dt ,
" train/tok_per_sec " : tok_per_sec ,
" train/mfu " : mfu ,
2026-01-13 20:05:47 +00:00
" train/epoch " : epoch ,
2025-11-05 21:08:30 +00:00
}
wandb_run . log ( log_data )
2025-10-13 06:49:24 -07:00
2025-11-13 15:34:40 +00:00
# state update
2026-02-02 01:44:30 +00:00
first_step_of_run = ( step == 0 ) or ( resuming and step == args . resume_from_step )
2025-11-13 15:34:40 +00:00
step + = 1
2026-02-02 01:44:30 +00:00
# The garbage collector is sadly a little bit overactive and for some poorly understood reason,
# it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
# So we manually manage and help it out here
if first_step_of_run :
gc . collect ( ) # manually collect a lot of garbage from setup
gc . freeze ( ) # immediately freeze all currently surviving objects and exclude them from GC
gc . disable ( ) # nuclear intervention here: disable GC entirely except:
elif step % 5000 == 0 : # every 5000 steps...
gc . collect ( ) # manually collect, just to be safe for very, very long runs
2025-10-13 06:49:24 -07:00
# print a few more stats
2025-10-16 16:14:38 +00:00
print0 ( f " Peak memory usage: { get_max_memory ( ) / 1024 / 1024 : .2f } MiB " )
2025-10-13 06:49:24 -07:00
print0 ( f " Total training time: { total_training_time / 60 : .2f } m " )
2026-01-05 00:38:09 +00:00
if val_bpb is not None :
2026-01-11 16:56:59 +00:00
print0 ( f " Minimum validation bpb: { min_val_bpb : .6f } " )
2025-10-13 06:49:24 -07:00
# Log to report
from nanochat . report import get_report
get_report ( ) . log ( section = " Base model training " , data = [
user_config , # CLI args
{ # stats about the training setup
" Number of parameters " : num_params ,
" Number of FLOPs per token " : f " { num_flops_per_token : e } " ,
" Calculated number of iterations " : num_iterations ,
" Number of training tokens " : total_tokens ,
2026-01-27 22:31:17 +00:00
" Tokens : Scaling params ratio " : args . total_batch_size * num_iterations / num_scaling_params ,
2025-10-13 06:49:24 -07:00
" DDP world size " : ddp_world_size ,
2026-01-04 19:14:23 +00:00
" warmup_ratio " : args . warmup_ratio ,
" warmdown_ratio " : args . warmdown_ratio ,
" final_lr_frac " : args . final_lr_frac ,
2025-10-13 06:49:24 -07:00
} ,
{ # stats about training outcomes
2026-01-05 00:38:09 +00:00
" Minimum validation bpb " : min_val_bpb if val_bpb is not None else None ,
2025-10-13 06:49:24 -07:00
" Final validation bpb " : val_bpb ,
2025-10-16 15:46:18 -07:00
" CORE metric estimate " : results . get ( " core_metric " , None ) ,
2025-10-13 06:49:24 -07:00
" MFU % " : f " { mfu : .2f } % " ,
" Total training flops " : f " { flops_so_far : e } " ,
" Total training time " : f " { total_training_time / 60 : .2f } m " ,
2025-10-16 16:14:38 +00:00
" Peak memory usage " : f " { get_max_memory ( ) / 1024 / 1024 : .2f } MiB " ,
2025-10-13 06:49:24 -07:00
}
] )
# cleanup
wandb_run . finish ( ) # wandb run finish
compute_cleanup ( )