2025-10-13 06:49:24 -07:00
"""
2026-01-07 22:11:52 +00:00
Train model. From root directory of the project, run as:
2025-10-13 06:49:24 -07:00
2026-01-25 18:59:51 +00:00
python -m scripts.base_train
2025-10-13 06:49:24 -07:00
or distributed as:
2026-01-25 18:59:51 +00:00
torchrun --nproc_per_node=8 -m scripts.base_train
2025-10-16 16:14:38 +00:00
2025-10-16 15:46:18 -07:00
If you are only on CPU/MacBook, you'll want to train a much, much smaller LLM. Example:
2026-01-13 22:45:27 +00:00
python -m scripts.base_train --depth=4 --max-seq-len=512 --device-batch-size=1 --eval-tokens=512 --core-metric-every=-1 --total-batch-size=512 --num-iterations=20
2025-10-13 06:49:24 -07:00
"""
import os
2025-12-28 03:32:46 +00:00
os . environ [ " PYTORCH_ALLOC_CONF " ] = " expandable_segments:True "
2026-02-05 19:40:37 +00:00
import gc
import json
2025-10-13 06:49:24 -07:00
import time
2026-02-05 19:40:37 +00:00
import math
import argparse
from dataclasses import asdict
2026-03-04 23:55:24 +00:00
from contextlib import contextmanager
2025-10-16 15:46:18 -07:00
2025-10-13 06:49:24 -07:00
import wandb
import torch
2026-03-04 23:55:24 +00:00
import torch . distributed as dist
2025-10-13 06:49:24 -07:00
2026-03-04 23:55:24 +00:00
from nanochat . gpt import GPT , GPTConfig , Linear
2026-01-13 20:05:47 +00:00
from nanochat . dataloader import tokenizing_distributed_data_loader_bos_bestfit , tokenizing_distributed_data_loader_with_state_bos_bestfit
2026-03-04 23:55:24 +00:00
from nanochat . common import compute_init , compute_cleanup , print0 , DummyWandb , print_banner , get_base_dir , autodetect_device_type , get_peak_flops , COMPUTE_DTYPE , COMPUTE_DTYPE_REASON , is_ddp_initialized
2025-10-13 06:49:24 -07:00
from nanochat . tokenizer import get_tokenizer , get_token_bytes
2025-11-13 15:34:40 +00:00
from nanochat . checkpoint_manager import save_checkpoint , load_checkpoint
2025-10-13 06:49:24 -07:00
from nanochat . loss_eval import evaluate_bpb
from nanochat . engine import Engine
2026-01-16 17:37:51 +00:00
from nanochat . flash_attention import HAS_FA3
2026-02-01 05:03:44 +00:00
from scripts . base_eval import evaluate_core
2025-10-13 06:49:24 -07:00
print_banner()
# -----------------------------------------------------------------------------
# CLI arguments
# Parse all command-line options; `user_config` keeps a plain-dict copy for wandb and checkpoint metadata.
parser = argparse.ArgumentParser(description="Pretrain base model")
# Logging
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
# FP8 training (uses nanochat's own nanochat.fp8 implementation, so torchao is not needed)
parser.add_argument("--fp8", action="store_true", help="enable FP8 training (requires H100+ GPU)")
parser.add_argument("--fp8-recipe", type=str, default="tensorwise", choices=["rowwise", "tensorwise"], help="FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate but slower)")
# Model architecture
parser.add_argument("--depth", type=int, default=20, help="depth of the Transformer model")
parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = depth * aspect_ratio")
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
parser.add_argument("--target-param-data-ratio", type=float, default=12, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)")
# Optimization
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size. good number to reduce to 16,8,4,... if you OOM on VRAM.")
parser.add_argument("--total-batch-size", type=int, default=-1, help="total batch size in tokens. decent numbers are e.g. 524288. (-1 = auto-compute optimal)")
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.008, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--weight-decay", type=float, default=0.28, help="cautious weight decay for the Muon optimizer (for weights)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--scalar-lr", type=float, default=0.5, help="learning rate for scalars (resid_lambdas, x0_lambdas)")
parser.add_argument("--warmup-steps", type=int, default=40, help="number of steps for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.65, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.05, help="final LR as fraction of initial LR")
parser.add_argument("--resume-from-step", type=int, default=-1, help="resume training from this step (-1 = disable)")
# Evaluation
parser.add_argument("--eval-every", type=int, default=250, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=80*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--core-metric-every", type=int, default=2000, help="evaluate CORE metric every N steps (-1 = disable)")
parser.add_argument("--core-metric-max-per-task", type=int, default=500, help="examples per task for CORE metric")
parser.add_argument("--sample-every", type=int, default=2000, help="sample from model every N steps (-1 = disable)")
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
2025-10-13 06:49:24 -07:00
# -----------------------------------------------------------------------------
2026-02-05 19:40:37 +00:00
# Compute init and wandb logging
2025-10-13 06:49:24 -07:00
2026-01-04 19:14:23 +00:00
# Resolve the device type (explicit --device-type, otherwise autodetect) and initialize the (possibly distributed) compute environment.
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
# CUDA-only helpers. On other devices they degrade to a no-op / constant 0 so callers can use them unconditionally.
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
    # Peak throughput is looked up from the name of device 0 (used for MFU reporting)
    gpu_device_name = torch.cuda.get_device_name(0)
    gpu_peak_flops = get_peak_flops(gpu_device_name)
    print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
    gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")

# wandb logging init
# Only the master process talks to wandb, and only for a real run name; everyone else gets a no-op DummyWandb.
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
2025-10-13 06:49:24 -07:00
2026-01-16 17:37:51 +00:00
# Flash Attention status
2026-03-04 23:55:24 +00:00
# Report which attention backend will be used, and warn loudly when falling back to PyTorch SDPA.
# NOTE(review): this import could live with the HAS_FA3 import at the top of the file.
from nanochat.flash_attention import USE_FA3
using_fa3 = USE_FA3
if using_fa3:
    print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
    print0("!" * 80)
    # FA3 may be installed but unusable because it only supports bf16
    if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
        print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
    else:
        print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
    print0("WARNING: Training will be less efficient without FA3")
    # Any 'S' (sliding window) layer in the pattern is a problem for the SDPA fallback
    if args.window_pattern != "L":
        print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
        print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
    print0("!" * 80)
2026-02-05 19:40:37 +00:00
# -----------------------------------------------------------------------------
# Tokenizer will be useful for evaluation and also we need the vocab size to init the model
2025-10-13 06:49:24 -07:00
# Load the tokenizer, plus the per-token byte table that evaluate_bpb uses to turn loss into bits-per-byte.
tokenizer = get_tokenizer()
token_bytes = get_token_bytes(device=device)
vocab_size = tokenizer.get_vocab_size() # also sizes the model's embedding/unembedding
print0(f"Vocab size: {vocab_size:,}")
# -----------------------------------------------------------------------------
# Initialize the Model
2025-11-13 15:34:40 +00:00
2026-02-05 19:40:37 +00:00
def build_model_meta(depth):
    """Construct a GPT with `depth` layers on the meta device (shapes/dtypes only, no data).

    The width is depth * args.aspect_ratio rounded *up* to a whole number of heads, so
    every head has exactly args.head_dim channels (FA3 needs head_dim divisible by 8).
    Query and key/value head counts are equal.
    """
    head_dim = args.head_dim
    # ceil(depth * aspect_ratio / head_dim), using integer-only arithmetic
    num_heads = -(-(depth * args.aspect_ratio) // head_dim)
    model_dim = num_heads * head_dim
    config = GPTConfig(
        sequence_len=args.max_seq_len,
        vocab_size=vocab_size,
        n_layer=depth,
        n_head=num_heads,
        n_kv_head=num_heads,
        n_embd=model_dim,
        window_pattern=args.window_pattern,
    )
    with torch.device("meta"):
        return GPT(config)
# Build the model, move to device, init the weights
model = build_model_meta(args.depth) # 1) Build on meta device (only shapes/dtypes, no data)
model_config = model.config
model_config_kwargs = asdict(model_config) # plain dict, also stored in checkpoint metadata
print0(f"Model config:\n{json.dumps(model_config_kwargs, indent=2)}")
model.to_empty(device=device) # 2) All tensors get storage on target device but with uninitialized (garbage) data
model.init_weights() # 3) All tensors get initialized

# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = args.model_tag if args.model_tag else f"d{args.depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = args.resume_from_step != -1
if resuming:
    print0(f"Resuming optimization from step {args.resume_from_step}")
    # optimizer_data and meta_data are consumed further below (optimizer state, dataloader and loop state)
    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, args.resume_from_step, device, load_optimizer=True, rank=ddp_rank)
    # assign=True makes the module adopt the checkpoint tensors themselves rather than copying into its own
    model.load_state_dict(model_data, strict=True, assign=True)
    del model_data # drop our reference to the loaded state dict
2026-02-03 20:51:26 +00:00
# -----------------------------------------------------------------------------
2026-02-05 19:40:37 +00:00
# FP8 training initialization and management (this has to be done before torch.compile)
2026-02-03 20:51:26 +00:00
# Convert Linear layers to Float8Linear if --fp8 is set
# Convert eligible nn.Linear layers of `model` to FP8 training (CUDA only).
if args.fp8:
    if device_type != "cuda":
        print0("Warning: FP8 training requires CUDA, ignoring --fp8 flag")
    else:
        # our custom fp8 is simpler than torchao, written for exact API compatibility
        from nanochat.fp8 import Float8LinearConfig, convert_to_float8_training
        # from torchao.float8 import Float8LinearConfig, convert_to_float8_training
        import torch.nn as nn

        # Filter: both dims must be divisible by 16 (FP8 hardware requirement), and the smaller
        # dim must be at least 128; layers failing either check stay in high precision.
        def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
            """Return True if `mod` should be converted to FP8 (`fqn` is part of the callback signature, unused)."""
            if not isinstance(mod, nn.Linear):
                return False
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
            if min(mod.in_features, mod.out_features) < 128:
                return False
            return True

        fp8_config = Float8LinearConfig.from_recipe_name(args.fp8_recipe)
        num_linear = sum(1 for m in model.modules() if isinstance(m, nn.Linear)) # counted before conversion
        convert_to_float8_training(model, config=fp8_config, module_filter_fn=fp8_module_filter)
        num_fp8 = sum(1 for m in model.modules() if 'Float8' in type(m).__name__)
        num_skipped = num_linear - num_fp8
        # Skipped layers failed either the size check or the divisibility check above
        print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small or not divisible by 16)")
2026-02-03 20:51:26 +00:00
# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager
def disable_fp8(model):
    """Temporarily swap Float8Linear modules with plain `Linear` for BF16 evaluation.

    CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead, every
    submodule whose class name contains 'Float8' is replaced by a nanochat `Linear` that
    shares the same weight/bias Parameters (no copy). The original modules are restored on
    exit, including when the body raises or a swap fails partway. If `model` has no FP8
    modules, this is a no-op.
    """
    # Find all Float8Linear modules and their locations: (parent_module, attr_name, fp8_module)
    fp8_locations = []
    for name, module in model.named_modules():
        if 'Float8' not in type(module).__name__:
            continue
        parent_name, _, attr_name = name.rpartition('.')
        parent = model.get_submodule(parent_name) if parent_name else model
        fp8_locations.append((parent, attr_name, module))
    if not fp8_locations:
        yield # No FP8 modules, nothing to do
        return
    try:
        # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype).
        # The swap runs inside the try so the finally restores every module even on partial failure.
        for parent, attr_name, fp8_module in fp8_locations:
            # Build on the meta device: its own weights are replaced immediately below, so there is
            # no point allocating and randomly initializing them on the real device each time.
            linear = Linear(
                fp8_module.in_features,
                fp8_module.out_features,
                bias=fp8_module.bias is not None,
                device="meta",
                dtype=fp8_module.weight.dtype,
            )
            linear.weight = fp8_module.weight # share, don't copy
            if fp8_module.bias is not None:
                linear.bias = fp8_module.bias
            setattr(parent, attr_name, linear)
        yield
    finally:
        # Restore Float8Linear modules
        for parent, attr_name, fp8_module in fp8_locations:
            setattr(parent, attr_name, fp8_module)
# -----------------------------------------------------------------------------
# Compile the model
# Keep a handle to the uncompiled model, then wrap it with torch.compile for training.
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (whose input shapes vary)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
# -----------------------------------------------------------------------------
# Scaling laws and muP extrapolations to determine the optimal training horizon, batch size, learning rates, weight decay.

# Get the parameter counts of our model
param_counts = model.num_scaling_params() # dict of named parameter groups; 'total' is read below
print0(f"Parameter counts:")
for key, value in param_counts.items():
    print0(f"{key:24s}: {value:,}")
num_params = param_counts['total']
num_flops_per_token = model.estimate_flops() # used below for FLOP-based horizons and reporting
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
2026-02-06 19:22:28 +00:00
# 1) Use scaling laws to determine the optimal training horizon in tokens
# The compute-optimal models satisfy the Tokens:Params ratio of --target-param-data-ratio (derived experimentally via scaling laws analysis).
# We've already initialized the model so we have Params. Optimal Tokens is now simply target-param-data-ratio * Params
def get_scaling_params(m):
    """Return the parameter count that scaling-law calculations are based on.

    Only the transformer matrices plus the lm_head are counted: per dev/LOG.md
    (Jan 27, 2026), this subset gives the cleanest scaling laws.
    """
    counts = m.num_scaling_params()
    return sum(counts[group] for group in ("transformer_matrices", "lm_head"))
2026-02-05 19:40:37 +00:00
# Scaling-law quantities for the model being trained and for the d12 reference model.
num_scaling_params = get_scaling_params(model)
target_tokens = int(args.target_param_data_ratio * num_scaling_params) # optimal tokens for the model we are about to train
# Our reference model is d12, this is where a lot of hyperparameters are tuned and then transferred to higher depths (muP style)
d12_ref = build_model_meta(12) # creates the model on meta device (shapes only, no weight storage)
D_REF = args.target_param_data_ratio * get_scaling_params(d12_ref) # compute-optimal d12 training horizon in tokens (measured empirically)
B_REF = 2**19 # optimal batch size at d12 ~= 524,288 tokens (measured empirically)
# NOTE(review): with --target-param-data-ratio disabled (-1), target_tokens and D_REF are both negative, so the
# ratios target_tokens / D_REF and D_REF / target_tokens used below stay positive; a ratio of exactly 0 would divide by zero.
2026-02-05 19:40:37 +00:00
2026-02-06 19:22:28 +00:00
# 2) Now that we have the token horizon, we can calculate the optimal batch size
# We follow the Power Lines paper (Bopt ∝ D^0.383), ref: https://arxiv.org/abs/2505.13738
# The optimal batch size grows as approximately D^0.383, so e.g. if D doubles from d12 to d24, B should grow by 2^0.383 ≈ 1.3x.
total_batch_size = args.total_batch_size # user-provided override is possible
if total_batch_size == -1:
    # B = B_REF * (D / D_REF)^0.383, then rounded (in log space) to the nearest power of 2
    batch_size_ratio = target_tokens / D_REF
    predicted_batch_size = B_REF * batch_size_ratio ** 0.383
    total_batch_size = 2 ** round(math.log2(predicted_batch_size)) # round to nearest power of 2 for efficiency
    print0(f"Auto-computed optimal batch size: {total_batch_size:,} tokens")

# 3) Knowing the batch size, we can now calculate a learning rate correction (bigger batch size allows higher learning rates)
batch_lr_scale = 1.0
batch_ratio = total_batch_size / B_REF # B/B_ref
if batch_ratio != 1.0:
    # SGD: linear scaling with batch size is standard (not used in nanochat)
    # AdamW: sqrt scaling is standard: η ∝ √(B/B_ref)
    # Muon: we will use the same scaling for Muon as for AdamW: η ∝ √(B/B_ref) (not studied carefully, assumption!)
    batch_lr_scale = batch_ratio ** 0.5 # η ∝ √(B/B_ref)
    print0(f"Scaling LRs by {batch_lr_scale:.4f} for batch size {total_batch_size:,} (reference: {B_REF:,})")
# 4) Knowing the batch size and the token horizon, we can now calculate the appropriate weight decay scaling
# We adopt the T_epoch framework from https://arxiv.org/abs/2405.13698
# Central idea of the paper is that T_epoch = B/(η·λ·D) should remain constant.
# Above, we used learning rate scaling η ∝ √(B/B_ref). So it's a matter of ~10 lines of math to derive that to keep T_epoch constant, we need:
# λ = λ_ref · √(B/B_ref) · (D_ref/D)
# Note that these papers study AdamW, *not* Muon. We are blindly following AdamW theory for scaling hoping it ~works for Muon too.
# D here is target_tokens (the data:param-ratio horizon), even when --num-iterations/--target-flops later set the actual horizon.
weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens)
if weight_decay_scaled != args.weight_decay:
    print0(f"Scaling weight decay from {args.weight_decay:.6f} to {weight_decay_scaled:.6f} for depth {args.depth}")
# -----------------------------------------------------------------------------
2026-01-29 00:50:50 +00:00
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
# All learning rates are pre-multiplied by the batch-size LR correction; the training loop later rescales each
# group as group["initial_lr"] * lrm (initial_lr presumably recorded by setup_optimizer — confirm).
optimizer = model.setup_optimizer(
    # AdamW hyperparameters
    unembedding_lr=args.unembedding_lr * batch_lr_scale,
    embedding_lr=args.embedding_lr * batch_lr_scale,
    scalar_lr=args.scalar_lr * batch_lr_scale,
    # Muon hyperparameters
    matrix_lr=args.matrix_lr * batch_lr_scale,
    weight_decay=weight_decay_scaled,
)

# Restore the optimizer state saved in the checkpoint, then drop our reference to it
if resuming:
    optimizer.load_state_dict(optimizer_data)
    del optimizer_data

# -----------------------------------------------------------------------------
# GradScaler for fp16 training (bf16/fp32 don't need it — bf16 has the same exponent range as fp32)
scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None
if scaler is not None:
    print0("GradScaler enabled for fp16 training")
2025-11-13 15:34:40 +00:00
# -----------------------------------------------------------------------------
2025-10-13 06:49:24 -07:00
# Initialize the DataLoaders for train/val
2025-11-13 15:34:40 +00:00
# When resuming, the train loader continues from the dataloader state saved in the checkpoint metadata.
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
# Factory rather than a single iterator: each evaluation builds a fresh val loader
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
2025-10-13 06:49:24 -07:00
# -----------------------------------------------------------------------------
2026-02-06 19:22:28 +00:00
# Calculate the number of iterations we will train for and set up the various schedulers
# num_iterations: either it is given, or from target flops, or from target data:param ratio (in that order)
# Resolve num_iterations from the first enabled horizon option, then report the resulting token/FLOP budget.
if args.num_iterations > 0:
    # Override num_iterations to a specific value if given
    num_iterations = args.num_iterations
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif args.target_flops > 0:
    # Calculate the number of iterations from the target flops (used in scaling laws analysis, e.g. runs/scaling_laws.sh)
    num_iterations = round(args.target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif args.target_param_data_ratio > 0:
    # Calculate the number of iterations from the target param data ratio (the most common use case)
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    # Explicit raise (not assert) so the check also survives `python -O` and tells the user what to set
    raise ValueError("No training horizon specified: set one of --num-iterations, --target-flops or --target-param-data-ratio to a positive value")
if num_iterations <= 0:
    # Rounding can produce a zero-step horizon; the run would only evaluate/save at step 0
    print0(f"WARNING: computed num_iterations={num_iterations}; no optimization steps will be taken")
total_tokens = total_batch_size * num_iterations # the actual number of tokens we will train for
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Scaling params ratio: {total_tokens / num_scaling_params:.2f}") # e.g. Chinchilla was ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
2025-10-13 06:49:24 -07:00
2026-02-06 19:22:28 +00:00
# Learning rate schedule (linear warmup, constant, linear warmdown)
2025-10-13 06:49:24 -07:00
def get_lr_multiplier(it):
    """Return the learning-rate multiplier for optimization step `it`.

    Linear warmup from 1/warmup_steps to 1.0 over the first args.warmup_steps steps
    (warmup takes precedence if it overlaps the warmdown). The multiplier then stays
    at 1.0, and finally falls linearly over the last
    round(args.warmdown_ratio * num_iterations) steps. It reaches args.final_lr_frac
    at it == num_iterations.

    Steps past the horizon stay at args.final_lr_frac instead of extrapolating.
    A zero-length warmdown keeps the multiplier at 1.0 instead of dividing by zero.
    """
    warmup_iters = args.warmup_steps
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    if warmdown_iters <= 0 or it <= num_iterations - warmdown_iters:
        return 1.0
    # progress goes 1 -> 0 across the warmdown; clamp at 0 so steps past the horizon don't go below final_lr_frac
    progress = max(0.0, (num_iterations - it) / warmdown_iters)
    return progress * 1.0 + (1 - progress) * args.final_lr_frac
2025-10-13 06:49:24 -07:00
2026-03-14 17:03:06 +00:00
# Momentum scheduler for Muon optimizer (warms up to 0.97, warms down to 0.90 during LR warmdown)
2025-10-13 06:49:24 -07:00
def get_muon_momentum(it):
    """Return the Muon momentum for optimization step `it`.

    Linear warmup from 0.85 to 0.97 over the first 400 steps, then 0.97, then a linear
    decay to 0.90 across the LR warmdown window (the last
    round(args.warmdown_ratio * num_iterations) steps). The warmup phase takes
    precedence if the two overlap.

    Progress is clamped so steps past the horizon stay at 0.90. A zero-length warmdown
    keeps momentum at 0.97 instead of dividing by zero.
    """
    warmdown_iters = round(args.warmdown_ratio * num_iterations)
    warmdown_start = num_iterations - warmdown_iters
    if it < 400:
        frac = it / 400
        return (1 - frac) * 0.85 + frac * 0.97
    elif warmdown_iters > 0 and it >= warmdown_start:
        progress = min(1.0, (it - warmdown_start) / warmdown_iters)
        return 0.97 * (1 - progress) + 0.90 * progress
    else:
        return 0.97
2025-10-13 06:49:24 -07:00
2026-03-09 20:45:17 +00:00
# Weight decay scheduler for Muon optimizer (cosine decay to zero over the course of training)
2026-01-11 16:56:59 +00:00
def get_weight_decay(it):
    """Cosine-anneal the Muon weight decay.

    Starts at weight_decay_scaled for it == 0 and reaches 0 at it == num_iterations.
    """
    angle = math.pi * it / num_iterations
    return weight_decay_scaled * 0.5 * (1 + math.cos(angle))
2026-01-11 16:56:59 +00:00
2025-11-13 15:34:40 +00:00
# -----------------------------------------------------------------------------
2026-02-06 19:22:28 +00:00
# Training loop
2025-11-13 15:34:40 +00:00
2026-02-06 19:22:28 +00:00
# Loop state (variables updated by the training loop)
2025-11-13 15:34:40 +00:00
# Initialize loop state fresh, or restore it from the checkpoint metadata when resuming.
if not resuming:
    step = 0
    val_bpb = None # will be set if eval_every > 0
    min_val_bpb = float("inf")
    smooth_train_loss = 0 # EMA of training loss
    total_training_time = 0 # total wall-clock time of training
else:
    step = meta_data["step"]
    loop_state = meta_data["loop_state"]
    val_bpb = meta_data["val_bpb"]
    min_val_bpb = loop_state["min_val_bpb"]
    smooth_train_loss = loop_state["smooth_train_loss"]
    total_training_time = loop_state["total_training_time"]
# Figure out the needed gradient accumulation micro-steps to reach the desired total batch size per step
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
# Explicit check (not assert) so it survives `python -O`. A non-positive total would give zero micro-steps.
if total_batch_size <= 0 or total_batch_size % world_tokens_per_fwdbwd != 0:
    raise ValueError(
        f"total_batch_size ({total_batch_size:,}) must be a positive multiple of "
        f"device_batch_size * max_seq_len * world_size = "
        f"{args.device_batch_size} * {args.max_seq_len} * {ddp_world_size} = {world_tokens_per_fwdbwd:,}; "
        f"adjust --total-batch-size, --device-batch-size or --max-seq-len"
    )
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# Go!
2025-11-13 15:34:40 +00:00
while True :
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
2026-02-05 19:59:46 +00:00
flops_so_far = num_flops_per_token * total_batch_size * step
2025-10-13 06:49:24 -07:00
# once in a while: evaluate the val bpb (all ranks participate)
2026-01-05 00:38:09 +00:00
if args . eval_every > 0 and ( last_step or step % args . eval_every == 0 ) :
2025-10-13 06:49:24 -07:00
model . eval ( )
val_loader = build_val_loader ( )
2026-01-04 19:14:23 +00:00
eval_steps = args . eval_tokens / / ( args . device_batch_size * args . max_seq_len * ddp_world_size )
2026-03-04 23:55:24 +00:00
with disable_fp8 ( model ) :
2025-10-13 06:49:24 -07:00
val_bpb = evaluate_bpb ( model , val_loader , eval_steps , token_bytes )
2026-01-11 16:56:59 +00:00
print0 ( f " Step { step : 05d } | Validation bpb: { val_bpb : .6f } " )
2025-10-13 06:49:24 -07:00
if val_bpb < min_val_bpb :
min_val_bpb = val_bpb
wandb_run . log ( {
" step " : step ,
" total_training_flops " : flops_so_far ,
" total_training_time " : total_training_time ,
" val/bpb " : val_bpb ,
} )
model . train ( )
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
2026-02-03 20:51:26 +00:00
# disable FP8 for evaluation to use BF16 for more consistent/accurate results
2025-10-16 15:46:18 -07:00
results = { }
2026-01-04 19:14:23 +00:00
if args . core_metric_every > 0 and ( last_step or ( step > 0 and step % args . core_metric_every == 0 ) ) :
2025-10-13 06:49:24 -07:00
model . eval ( )
2026-03-04 23:55:24 +00:00
with disable_fp8 ( orig_model ) :
2026-02-01 05:03:44 +00:00
results = evaluate_core ( orig_model , tokenizer , device , max_per_task = args . core_metric_max_per_task )
2025-10-13 06:49:24 -07:00
print0 ( f " Step { step : 05d } | CORE metric: { results [ ' core_metric ' ] : .4f } " )
wandb_run . log ( {
" step " : step ,
" total_training_flops " : flops_so_far ,
" core_metric " : results [ " core_metric " ] ,
" centered_results " : results [ " centered_results " ] ,
} )
model . train ( )
# once in a while: sample from the model (only on master process)
# use the original uncompiled model because the inputs keep changing shape
2026-01-05 00:38:09 +00:00
if args . sample_every > 0 and master_process and ( last_step or ( step > 0 and step % args . sample_every == 0 ) ) :
2025-10-13 06:49:24 -07:00
model . eval ( )
prompts = [
" The capital of France is " ,
" The chemical symbol of gold is " ,
" If yesterday was Friday, then tomorrow will be " ,
" The opposite of hot is " ,
" The planets of the solar system are: " ,
" My favorite color is " ,
" If 5*x + 3 = 13, then x is " ,
]
2025-10-20 00:05:09 +00:00
engine = Engine ( orig_model , tokenizer ) # use orig_model to avoid recompilation
2025-10-13 06:49:24 -07:00
for prompt in prompts :
tokens = tokenizer ( prompt , prepend = " <|bos|> " )
2026-03-04 23:55:24 +00:00
with disable_fp8 ( orig_model ) :
2025-10-13 06:49:24 -07:00
sample , _ = engine . generate_batch ( tokens , num_samples = 1 , max_tokens = 16 , temperature = 0 )
print0 ( tokenizer . decode ( sample [ 0 ] ) )
model . train ( )
2025-11-13 15:34:40 +00:00
# Save a checkpoint at the end of the run. Also save every save_every steps, but
# never at step 0 and never at the step we just resumed from (that state is
# already on disk).
# The `and` chain short-circuits, so the modulo only runs when save_every > 0.
periodic_save = (
    step > 0
    and step != args.resume_from_step
    and args.save_every > 0
    and step % args.save_every == 0
)
if last_step or periodic_save:
    # Metadata is persisted as JSON next to the model/optimizer state.
    checkpoint_meta = {
        "step": step,
        "val_bpb": val_bpb,  # most recently computed validation bpb (may be None if never evaluated)
        "model_config": model_config_kwargs,
        "user_config": user_config,  # inputs to the training script
        "device_batch_size": args.device_batch_size,
        "max_seq_len": args.max_seq_len,
        "total_batch_size": total_batch_size,
        "dataloader_state_dict": dataloader_state_dict,
        # every piece of loop state besides `step` that a resumed run must restore
        "loop_state": {
            "min_val_bpb": min_val_bpb,
            "smooth_train_loss": smooth_train_loss,
            "total_training_time": total_training_time,
        },
    }
    save_checkpoint(
        checkpoint_dir,
        step,
        orig_model.state_dict(),  # model parameters
        optimizer.state_dict(),  # optimizer state
        checkpoint_meta,
        rank=ddp_rank,
    )
2025-11-13 15:34:40 +00:00
# termination conditions (TODO: possibly also add loss explosions etc.)
2025-10-13 06:49:24 -07:00
# Leave the loop only after the sampling and checkpoint branches above (which both
# test last_step) have run for the final step, and before another optimizer step.
if last_step:
    break
# -------------------------------------------------------------------------
# single training step: accumulate gradients over grad_accum_steps micro-batches
synchronize()  # presumably drains pending device work so the timer below covers only this step
t0 = time.time()
for micro_step in range(grad_accum_steps):
    loss = model(x, y)
    train_loss = loss.detach()  # autograd-free copy, used only for logging
    # Successive .backward() calls sum gradients, so scale each micro-loss down.
    loss = loss / grad_accum_steps
    # With a GradScaler, backprop through the scaled loss; otherwise through the loss itself.
    backward_target = loss if scaler is None else scaler.scale(loss)
    backward_target.backward()
    # Pull the next batch now so data loading overlaps with the queued device work.
    x, y, dataloader_state_dict = next(train_loader)
2026-01-29 00:50:50 +00:00
# step the optimizer: first push this step's scheduled hyperparameters into every param group
lrm = get_lr_multiplier(step)
muon_momentum = get_muon_momentum(step)
muon_weight_decay = get_weight_decay(step)
for group in optimizer.param_groups:
    group["lr"] = group["initial_lr"] * lrm
    if group["kind"] != "muon":
        continue
    # momentum and weight decay are only scheduled for Muon groups
    group["momentum"] = muon_momentum
    group["weight_decay"] = muon_weight_decay
if scaler is None:
    optimizer.step()
else:
    scaler.unscale_(optimizer)
    if is_ddp_initialized():
        # Ranks can see inf/nan gradients independently, but they must all make the same
        # skip-or-step decision. A MAX all-reduce of the per-device found_inf flags
        # (a private GradScaler API) makes any rank's overflow skip the step everywhere.
        found_inf = scaler._found_inf_per_device(optimizer)
        for flag in found_inf.values():
            dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    scaler.step(optimizer)
    scaler.update()
model.zero_grad(set_to_none=True)
train_loss_f = train_loss.item()  # .item() forces a CPU-GPU sync
synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
# logging (CPU action only)
# train_loss_f comes from the final micro-batch of this step's accumulation loop.
ema_beta = 0.9  # EMA decay factor for some smoothing just for nicer logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f  # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))  # debias the EMA
pct_done = 100 * step / num_iterations
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
# MFU is measured against the combined peak FLOPs of all ranks.
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
if step > 10:
    total_training_time += dt  # only count the time after the first 10 steps (presumably warm-up)
# Calculate ETA based on average time per step (excluding first 10 steps)
steps_done = step - 10
if steps_done > 0:
    avg_time_per_step = total_training_time / steps_done
    remaining_steps = num_iterations - step
    eta_seconds = remaining_steps * avg_time_per_step
    eta_str = f" | eta: {eta_seconds/60:.1f}m"
else:
    eta_str = ""
# Human-readable data position for the console: the epoch plus the pq_idx / rg_idx
# loader indices (presumably parquet file / row group).
epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
if step % 100 == 0:
    log_data = {
        "step": step,
        "total_training_flops": flops_so_far,
        "total_training_time": total_training_time,
        "train/loss": debiased_smooth_loss,
        "train/lrm": lrm,
        "train/dt": dt,
        "train/tok_per_sec": tok_per_sec,
        "train/mfu": mfu,
        # Log the raw epoch counter, not the formatted console string above.
        # A string value cannot be charted as a wandb metric.
        "train/epoch": dataloader_state_dict["epoch"],
    }
    wandb_run.log(log_data)
2025-10-13 06:49:24 -07:00
2025-11-13 15:34:40 +00:00
# State update: record whether the step that just finished was the first one this
# process ran (step 0 of a fresh run, or the step it resumed from), then advance.
first_step_of_run = step == 0 or (resuming and step == args.resume_from_step)
step += 1
# Python's cyclic GC was observed to spend ~500ms on frequent scans that reclaim very
# little, for reasons that are poorly understood. So we manage it by hand:
#  - after the run's first step: collect the setup garbage once, freeze every survivor
#    (exclude it from future scans), and turn automatic collection off entirely;
#  - afterwards, collect manually every 5000 steps as a safety net for very long runs.
if first_step_of_run:
    gc.collect()
    gc.freeze()
    gc.disable()
elif not step % 5000:
    gc.collect()
2025-10-13 06:49:24 -07:00
# Summarize the run on the console.
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
if val_bpb is not None:
    print0(f"Minimum validation bpb: {min_val_bpb:.6f}")

# Record the run in the report: CLI args, the training setup, and the outcomes.
from nanochat.report import get_report
report = get_report()
setup_stats = {
    "Number of parameters": num_params,
    "Number of FLOPs per token": f"{num_flops_per_token:e}",
    "Calculated number of iterations": num_iterations,
    "Number of training tokens": total_tokens,
    "Tokens : Scaling params ratio": total_batch_size * num_iterations / num_scaling_params,
    "DDP world size": ddp_world_size,
    "warmup_steps": args.warmup_steps,
    "warmdown_ratio": args.warmdown_ratio,
    "final_lr_frac": args.final_lr_frac,
}
# NOTE(review): `mfu` is assigned inside the training step of the loop above. If the loop
# exits before running any step (e.g. resuming at num_iterations), this raises NameError
# unless mfu is initialised earlier in the file. Confirm.
outcome_stats = {
    "Minimum validation bpb": min_val_bpb if val_bpb is not None else None,
    "Final validation bpb": val_bpb,
    "CORE metric estimate": results.get("core_metric", None),
    "MFU %": f"{mfu:.2f}%",
    "Total training flops": f"{flops_so_far:e}",
    "Total training time": f"{total_training_time/60:.2f}m",
    "Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
report.log(section="Base model training", data=[user_config, setup_stats, outcome_stats])

# Cleanup: close the wandb run, then tear down the compute setup.
wandb_run.finish()
compute_cleanup()