2025-10-13 06:49:24 -07:00
"""
Utilities for saving and loading model/optim/state checkpoints.
"""
import os
import re
import glob
import json
import logging
import torch
from nanochat . common import get_base_dir
from nanochat . gpt import GPT , GPTConfig
from nanochat . tokenizer import get_tokenizer
from nanochat . common import setup_default_logging
# Set up logging
# setup_default_logging() comes from nanochat.common and runs once, at import
# time; presumably it installs the project's default handler/format (TODO confirm).
setup_default_logging()
# Module-level logger; log0() and save_checkpoint() write through it.
logger = logging.getLogger(__name__)
def log0(message):
    """
    Log ``message`` at INFO level, but only on the process whose RANK is 0.

    RANK is read from the environment and treated as 0 when it is unset, so
    single-process runs always emit the message.
    """
    rank = int(os.environ.get('RANK', 0))
    if rank != 0:
        # Non-primary ranks stay quiet to avoid duplicated log lines.
        return
    logger.info(message)
2026-01-11 21:49:54 +00:00
def _patch_missing_config_keys ( model_config_kwargs ) :
""" Add default values for new config keys missing in old checkpoints. """
# Old models were trained with full context (no sliding window)
if " window_pattern " not in model_config_kwargs :
model_config_kwargs [ " window_pattern " ] = " L "
2026-01-13 22:09:36 +00:00
log0 ( f " Patching missing window_pattern in model config to ' L ' " )
2026-01-11 21:49:54 +00:00
2026-01-11 20:13:12 +00:00
def _patch_missing_keys ( model_data , model_config ) :
""" Add default values for new parameters that may be missing in old checkpoints. """
n_layer = model_config . n_layer
# resid_lambdas defaults to 1.0 (identity scaling)
if " resid_lambdas " not in model_data :
model_data [ " resid_lambdas " ] = torch . ones ( n_layer )
2026-01-13 22:09:36 +00:00
log0 ( f " Patching missing resid_lambdas in model data to 1.0 " )
2026-01-11 20:13:12 +00:00
# x0_lambdas defaults to 0.0 (disabled)
if " x0_lambdas " not in model_data :
model_data [ " x0_lambdas " ] = torch . zeros ( n_layer )
2026-01-13 22:09:36 +00:00
log0 ( f " Patching missing x0_lambdas in model data to 0.0 " )
2026-01-11 20:13:12 +00:00
2025-11-13 15:34:40 +00:00
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
    """
    Write one training checkpoint into ``checkpoint_dir``.

    Rank 0 writes the shared pieces: the model state as ``model_<step>.pt``
    and the metadata dict as JSON in ``meta_<step>.json``. Any rank given
    non-None ``optimizer_data`` also writes ``optim_<step>_rank<rank>.pt``,
    because optimizer state is sharded across ranks and each rank owns its
    own shard. Step numbers are zero-padded to six digits in every file name.
    """
    is_primary = rank == 0
    if is_primary:
        os.makedirs(checkpoint_dir, exist_ok=True)
        model_file = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
        torch.save(model_data, model_file)
        logger.info(f"Saved model parameters to: {model_file}")
        meta_file = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
        with open(meta_file, "w", encoding="utf-8") as fh:
            json.dump(meta_data, fh, indent=2)
        logger.info(f"Saved metadata to: {meta_file}")
    if optimizer_data is None:
        return
    # Non-zero ranks skip the branch above, so they may get here before the
    # directory exists.
    os.makedirs(checkpoint_dir, exist_ok=True)
    shard_file = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
    torch.save(optimizer_data, shard_file)
    logger.info(f"Saved optimizer state to: {shard_file}")
2025-10-13 06:49:24 -07:00
2025-11-13 15:34:40 +00:00
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
    """
    Read back the files written by save_checkpoint for ``step``.

    Tensors are mapped onto ``device`` as they are loaded. The optimizer
    shard for ``rank`` is read only when ``load_optimizer`` is true;
    otherwise None is returned in its place.

    Returns:
        (model_data, optimizer_data, meta_data)
    """
    def in_dir(filename):
        return os.path.join(checkpoint_dir, filename)

    model_data = torch.load(in_dir(f"model_{step:06d}.pt"), map_location=device)
    if load_optimizer:
        optimizer_data = torch.load(in_dir(f"optim_{step:06d}_rank{rank:d}.pt"), map_location=device)
    else:
        optimizer_data = None
    with open(in_dir(f"meta_{step:06d}.json"), "r", encoding="utf-8") as fh:
        meta_data = json.load(fh)
    return model_data, optimizer_data, meta_data
def build_model(checkpoint_dir, step, device, phase):
    """
    Build a ready-to-use model from the checkpoint saved at ``step``.

    Args:
        checkpoint_dir: directory holding ``model_<step>.pt`` and ``meta_<step>.json``.
        step: checkpoint step to load.
        device: target device, either a ``torch.device`` or anything
            ``torch.device()`` accepts (e.g. ``"cuda"``, ``"cpu"``).
        phase: ``"train"`` or ``"eval"``; selects ``model.train()`` / ``model.eval()``.

    Returns:
        - base model - uncompiled, not wrapped in DDP
        - tokenizer
        - meta data saved alongside the checkpoint (its "model_config" entry
          is patched in place with defaults for keys old checkpoints lack)
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    # Normalize so callers may pass a plain string; `.type` below needs a torch.device.
    device = torch.device(device)
    model_data, _, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)
    if device.type in {"cpu", "mps"}:
        # Convert bfloat16 tensors to float for CPU inference
        model_data = {
            k: v.float() if v.dtype == torch.bfloat16 else v
            for k, v in model_data.items()
        }
    # Hack: fix torch compile issue, which prepends all keys with _orig_mod.
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
    model_config_kwargs = meta_data["model_config"]
    _patch_missing_config_keys(model_config_kwargs)
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    _patch_missing_keys(model_data, model_config)
    # Build on the meta device (no real allocation), then materialize on `device`.
    with torch.device("meta"):
        model = GPT(model_config)
    # Load the model state
    model.to_empty(device=device)
    model.init_weights()  # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    model.load_state_dict(model_data, strict=True, assign=True)
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
    else:
        model.train()
    # Load the Tokenizer
    tokenizer = get_tokenizer()
    # Sanity check: compatibility between model and tokenizer
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"], f"Tokenizer vocab size {tokenizer.get_vocab_size()} does not match model config vocab size {model_config_kwargs['vocab_size']}"
    return model, tokenizer, meta_data
2025-11-19 15:33:36 -05:00
def find_largest_model(checkpoints_dir):
    """
    Guess which model tag to use inside ``checkpoints_dir``.

    Only sub-directories count as tags. Tags starting with ``d<number>`` are
    preferred, and the largest number wins (the earliest in directory-listing
    order on ties). If no tag has that form, the most recently modified
    sub-directory is returned instead.

    Raises:
        FileNotFoundError: if ``checkpoints_dir`` contains no sub-directories.
    """
    tags = [
        name for name in os.listdir(checkpoints_dir)
        if os.path.isdir(os.path.join(checkpoints_dir, name))
    ]
    if not tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
    # 1) Normally every tag looks like d<depth>: take the deepest model.
    depth_tagged = [
        (int(m.group(1)), tag)
        for tag in tags
        if (m := re.match(r"d(\d+)", tag))
    ]
    if depth_tagged:
        return max(depth_tagged, key=lambda pair: pair[0])[1]
    # 2) Otherwise fall back to the most recently updated directory.
    return max(tags, key=lambda tag: os.path.getmtime(os.path.join(checkpoints_dir, tag)))
2025-10-13 06:49:24 -07:00
def find_last_step(checkpoint_dir):
    """
    Return the highest step among ``model_<step>.pt`` files in ``checkpoint_dir``.

    Steps are compared as integers, not strings. save_checkpoint zero-pads
    step numbers to six digits, so from step 1,000,000 onward the names gain
    a seventh digit, and a string comparison would rank ``model_999999.pt``
    above ``model_1000000.pt``. Files that match ``model_*.pt`` but whose
    suffix is not a plain number (e.g. ``model_best.pt``) are ignored rather
    than crashing ``int()``.

    Raises:
        FileNotFoundError: if no numbered model checkpoint exists.
    """
    steps = []
    for path in glob.glob(os.path.join(checkpoint_dir, "model_*.pt")):
        match = re.fullmatch(r"model_([0-9]+)\.pt", os.path.basename(path))
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    return max(steps)
# -----------------------------------------------------------------------------
# convenience functions that take into account nanochat's directory structure
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None):
    """
    Load the model stored under ``checkpoints_dir/<model_tag>`` at ``step``.

    A missing ``model_tag`` defaults to find_largest_model's guess, and a
    missing ``step`` defaults to the latest step found by find_last_step.
    Returns build_model's ``(model, tokenizer, meta_data)`` triple.
    """
    if model_tag is None:
        # Default to the largest available model.
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
    # Default to the most recent checkpoint step.
    step = find_last_step(checkpoint_dir) if step is None else step
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"
    log0(f"Loading model from {checkpoint_dir} with step {step}")
    model, tokenizer, meta_data = build_model(checkpoint_dir, step, device, phase)
    return model, tokenizer, meta_data
def load_model(source, *args, **kwargs):
    """
    Load a model from the checkpoint family named by ``source``.

    ``source`` is one of "base", "sft" or "rl" (any other value raises
    KeyError). Remaining arguments are forwarded to load_model_from_dir,
    rooted at the matching directory under the nanochat base dir.
    """
    subdir_by_source = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }
    subdir = subdir_by_source[source]
    checkpoints_dir = os.path.join(get_base_dir(), subdir)
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)
2026-02-16 14:41:53 +00:00
def load_optimizer_state(source, device, rank, model_tag=None, step=None):
    """
    Load just the optimizer shard for a given rank, without re-loading the model.

    ``source`` selects the checkpoint family ("base", "sft" or "rl"; any other
    value raises KeyError). ``model_tag`` and ``step`` default to the largest
    model and its latest step, the same way load_model_from_dir resolves them.
    The shard's tensors are mapped onto ``device``.
    """
    family_dirs = {
        "base": "base_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }
    subdir = family_dirs[source]
    checkpoints_dir = os.path.join(get_base_dir(), subdir)
    tag = find_largest_model(checkpoints_dir) if model_tag is None else model_tag
    checkpoint_dir = os.path.join(checkpoints_dir, tag)
    resolved_step = find_last_step(checkpoint_dir) if step is None else step
    shard_path = os.path.join(checkpoint_dir, f"optim_{resolved_step:06d}_rank{rank:d}.pt")
    log0(f"Loading optimizer state from {shard_path}")
    return torch.load(shard_path, map_location=device)