2025-10-13 06:49:24 -07:00
"""
Common utilities for nanochat.
"""
import os
import re
import logging
2025-10-24 14:02:48 +00:00
import urllib . request
2025-10-13 06:49:24 -07:00
import torch
import torch . distributed as dist
2025-11-04 07:22:34 +00:00
from filelock import FileLock
2025-10-13 06:49:24 -07:00
2026-03-04 23:55:24 +00:00
# The dtype used for compute (matmuls, activations). Master weights stay fp32 for optimizer precision.
# Linear layers cast their weights to this dtype in forward, replacing torch.amp.autocast.
# Override with NANOCHAT_DTYPE env var: "bfloat16", "float16", "float32"
_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}


def _detect_compute_dtype():
    """
    Choose the compute dtype and return it together with a human-readable reason.

    Resolution order:
      1. NANOCHAT_DTYPE env var, if set to a non-empty value. Matching against the
         keys of _DTYPE_MAP is case-insensitive and ignores surrounding whitespace.
      2. CUDA with compute capability >= (8, 0) -> bfloat16.
      3. Older CUDA GPUs -> float32 (fp16 would need a GradScaler, not implemented).
      4. No CUDA (CPU/MPS) -> float32.

    Returns:
        (dtype, reason): the chosen torch dtype and a short explanation string.

    Raises:
        ValueError: if NANOCHAT_DTYPE names a dtype not in _DTYPE_MAP.
    """
    env = os.environ.get("NANOCHAT_DTYPE")
    # An empty value (e.g. `export NANOCHAT_DTYPE=`) is treated the same as unset.
    if env:
        key = env.strip().lower()
        if key not in _DTYPE_MAP:
            raise ValueError(
                f"Invalid NANOCHAT_DTYPE={env!r}; expected one of {sorted(_DTYPE_MAP)}"
            )
        return _DTYPE_MAP[key], f"set via NANOCHAT_DTYPE={env}"
    if torch.cuda.is_available():
        # bf16 requires SM 80+ (Ampere: A100, A10, etc.)
        # Older GPUs like V100 (SM 70) and T4 (SM 75) only have fp16 tensor cores
        capability = torch.cuda.get_device_capability()
        if capability >= (8, 0):
            return torch.bfloat16, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (bf16 supported)"
        # fp16 training requires GradScaler (not yet implemented), so fall back to fp32.
        # Users can still force fp16 via NANOCHAT_DTYPE=float16 if they know what they're doing.
        return torch.float32, f"auto-detected: CUDA SM {capability[0]}{capability[1]} (pre-Ampere, bf16 not supported, using fp32)"
    return torch.float32, "auto-detected: no CUDA (CPU/MPS)"


COMPUTE_DTYPE, COMPUTE_DTYPE_REASON = _detect_compute_dtype()
2025-10-13 06:49:24 -07:00
class ColoredFormatter(logging.Formatter):
    """
    Custom formatter that adds colors to log messages.

    The level name is wrapped in its ANSI color plus bold. For INFO records,
    sizes/percentages/doc counts (e.g. "1.5 GB", "40%", "12 docs") are bolded
    and "Shard N" mentions are colored. The LogRecord passed in is left
    unmodified, so other handlers on the same logger see the plain level name.
    """
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',    # Cyan
        'INFO': '\033[32m',     # Green
        'WARNING': '\033[33m',  # Yellow
        'ERROR': '\033[31m',    # Red
        'CRITICAL': '\033[35m', # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'

    def format(self, record):
        """Format `record` with colors, without mutating it for other handlers."""
        levelname = record.levelname
        # Add color to the level name. The same record object is handed to every
        # handler on the logger, so the colored name is only installed for the
        # duration of super().format() and restored afterwards.
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        try:
            message = super().format(record)
        finally:
            record.levelname = levelname
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message
def setup_default_logging():
    """
    Route root-logger output through a ColoredFormatter stream handler at INFO level.

    Like any logging.basicConfig call, this has no effect if the root logger
    already has handlers configured.
    """
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logging.basicConfig(handlers=[stream_handler], level=logging.INFO)


setup_default_logging()
logger = logging.getLogger(__name__)
def get_base_dir():
    """
    Return the directory where nanochat keeps its intermediates, creating it if needed.

    Uses $NANOCHAT_BASE_DIR when it is set to a non-empty value; otherwise
    co-locates with other cached data under ~/.cache/nanochat.
    """
    base_dir = os.environ.get("NANOCHAT_BASE_DIR")
    if not base_dir:
        base_dir = os.path.join(os.path.expanduser("~"), ".cache", "nanochat")
    os.makedirs(base_dir, exist_ok=True)
    return base_dir
2025-11-01 16:04:38 +00:00
def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.

    The payload is streamed into a temporary sibling file and only renamed to
    its final name once fully written. An interrupted download therefore never
    leaves a truncated file that later calls would mistake for a finished one.
    Likewise, if `postprocess_fn` raises, the downloaded file is removed so the
    next call retries instead of silently skipping the postprocessing.

    Args:
        url: URL to fetch with urllib.request.urlopen.
        filename: file name, joined onto get_base_dir().
        postprocess_fn: optional callable, invoked as postprocess_fn(file_path)
            right after a fresh download (not when the file already existed).

    Returns:
        The local path os.path.join(get_base_dir(), filename).
    """
    import shutil  # only needed on the download path

    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"
    # Fast path: already downloaded, no need to touch the lock.
    if os.path.exists(file_path):
        return file_path

    with FileLock(lock_path):
        # Only a single rank can acquire this lock
        # All other ranks block until it is released
        # Recheck after acquiring lock: another rank may have finished meanwhile
        if os.path.exists(file_path):
            return file_path

        tmp_path = file_path + ".tmp"
        print(f"Downloading {url}...")
        try:
            with urllib.request.urlopen(url) as response, open(tmp_path, 'wb') as f:
                # Stream to disk instead of buffering the whole payload in memory
                shutil.copyfileobj(response, f)
            # Publish under the final name only once the content is complete
            os.replace(tmp_path, file_path)
        except BaseException:
            # Never leave a partial temp file behind; propagate the original error
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        print(f"Downloaded to {file_path}")

        # Run the postprocess function if provided
        if postprocess_fn is not None:
            try:
                postprocess_fn(file_path)
            except BaseException:
                # Drop the file so the next call re-downloads and re-runs postprocessing
                if os.path.exists(file_path):
                    os.remove(file_path)
                raise

    return file_path
2025-10-13 06:49:24 -07:00
def print0(s="", **kwargs):
    """print() that only emits on the process whose RANK env var is 0 (or unset)."""
    if int(os.environ.get('RANK', 0)) != 0:
        return
    print(s, **kwargs)
def print_banner():
    """Print the nanochat ASCII-art banner via print0 (so only the RANK 0 process prints it)."""
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    # NOTE(review): the art's internal spacing may have been collapsed in this copy; verify it renders as intended.
    banner = """
█████ █████
░░███ ░░███
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
"""
    print0(banner)
2025-12-27 23:27:40 -05:00
def is_ddp_requested() -> bool:
    """
    Report whether torchrun's environment (RANK, LOCAL_RANK, WORLD_SIZE) is present.

    This is meaningful even before any process group exists; it answers whether
    we *should* initialize one.
    """
    required = ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    return not any(name not in os.environ for name in required)
def is_ddp_initialized() -> bool:
    """
    Report whether a torch.distributed process group is actually up.

    Safe to call on builds without distributed support (returns False), which
    makes it the right guard before tearing a process group down.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()
2025-10-13 06:49:24 -07:00
def get_dist_info():
    """
    Read the distributed layout from torchrun's environment variables.

    Returns:
        (ddp, ddp_rank, ddp_local_rank, ddp_world_size). If RANK, LOCAL_RANK and
        WORLD_SIZE are not all present, returns the single-process default
        (False, 0, 0, 1).

    Only the environment is inspected here; the process group itself is
    initialized in compute_init.
    """
    if not is_ddp_requested():
        return False, 0, 0, 1
    ddp_rank, ddp_local_rank, ddp_world_size = (
        int(os.environ[var]) for var in ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    )
    return True, ddp_rank, ddp_local_rank, ddp_world_size
2025-10-16 10:26:19 -07:00
def autodetect_device_type():
    """
    Pick the best available backend, preferring CUDA, then MPS, then CPU.

    The choice is announced via print0 and returned as "cuda", "mps" or "cpu".
    """
    device_type = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print0(f"Autodetected device type: {device_type}")
    return device_type
2025-10-16 10:04:43 -07:00
def compute_init(device_type="cuda"):  # cuda|cpu|mps
    """
    Basic initialization that we keep doing over and over, so make common.

    Steps:
      - Validate that `device_type` is supported and available in this PyTorch build.
      - Seed the global RNGs (42).
      - On CUDA, allow TF32 matmuls.
      - When launched under torchrun on CUDA, bind this process to its local GPU
        and join an NCCL process group.

    Returns:
        (ddp, ddp_rank, ddp_local_rank, ddp_world_size, device).

    NOTE: `ddp` reflects whether torchrun's env vars are present. On mps/cpu it
    can be True even though no process group is created here.
    """
    assert device_type in ("cuda", "mps", "cpu"), "Invalid device type atm"
    if device_type == "cuda":
        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    elif device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility: global seeds only. Most of the code uses explicit rng objects;
    # the global rng mainly matters for nn.Module weight initialization.
    # Full determinism (torch.use_deterministic_algorithms) is skipped for now; possibly investigate slowdown later.
    torch.manual_seed(42)
    if device_type == "cuda":
        torch.cuda.manual_seed(42)
        # Precision: uses tf32 instead of fp32 for matmuls, see https://docs.pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
        torch.set_float32_matmul_precision("high")

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if not (ddp and device_type == "cuda"):
        device = torch.device(device_type)  # mps|cpu (or single-process cuda)
    else:
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device)  # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")
    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
2025-10-13 06:49:24 -07:00
def compute_cleanup():
    """Companion to compute_init: tear down the process group, if one exists, before script exit."""
    if not is_ddp_initialized():
        return
    dist.destroy_process_group()
class DummyWandb:
    """
    No-op stand-in exposing the same log/finish calls as a wandb run.

    Lets training code call run.log(...) / run.finish() unconditionally when
    wandb logging is disabled.
    """

    def __init__(self):
        """Nothing to set up."""

    def log(self, *args, **kwargs):
        """Accept and discard any logging call."""

    def finish(self):
        """Nothing to flush or close."""
2026-01-17 03:16:12 +00:00
2026-01-17 03:22:20 +00:00
# hardcoded BF16 peak flops for various GPUs
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
# and PR: https://github.com/karpathy/nanochat/pull/147
def get_peak_flops(device_name: str) -> float:
    """
    Return the hardcoded BF16 peak FLOP/s for a device, identified by name.

    Matching is case-insensitive substring matching against a lookup table.
    The first row whose patterns all occur in the name wins, so more specific
    rows must come before rows whose pattern is a substring of theirs.
    Intel "Data Center GPU Max 1550" is computed from its compute-unit count.
    Unknown devices log a warning and return float('inf'), so MFU reads as 0%.
    """
    name = device_name.lower()
    # Table order matters: more specific patterns first.
    _PEAK_FLOPS_TABLE = (
        # NVIDIA Blackwell
        (["gb200"], 2.5e15),
        (["grace blackwell"], 2.5e15),
        (["b200"], 2.25e15),
        (["b100"], 1.8e15),
        # NVIDIA Hopper
        (["h200", "nvl"], 836e12),
        (["h200", "pcie"], 836e12),
        (["h200"], 989e12),
        (["h100", "nvl"], 835e12),
        (["h100", "pcie"], 756e12),
        (["h100"], 989e12),
        (["h800", "nvl"], 989e12),
        (["h800"], 756e12),
        # NVIDIA Ampere data center
        (["a100"], 312e12),
        (["a800"], 312e12),
        (["a40"], 149.7e12),
        (["a30"], 165e12),
        # NVIDIA Ada data center
        (["l40s"], 362e12),
        (["l40-s"], 362e12),
        (["l40 s"], 362e12),
        # Plain L40 must precede "l4": "l4" is a substring of "l40", which
        # previously made an L40 report the L4's figure.
        (["l40"], 181.05e12),
        (["l4"], 121e12),
        # AMD CDNA accelerators
        (["mi355"], 2.5e15),
        (["mi325"], 1.3074e15),
        (["mi300x"], 1.3074e15),
        (["mi300a"], 980.6e12),
        (["mi250x"], 383e12),
        (["mi250"], 362.1e12),
        # Consumer RTX
        (["5090"], 209.5e12),
        (["4090"], 165.2e12),
        (["3090"], 71e12),
    )
    for patterns, flops in _PEAK_FLOPS_TABLE:
        if all(p in name for p in patterns):
            return flops
    if "data center gpu max 1550" in name:
        # Ponte Vecchio (PVC) - dynamic based on compute units
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    # Unknown GPU - return inf so MFU shows as 0% rather than a wrong guess
    logger.warning(f"Peak flops undefined for: {device_name}, MFU will show as 0%")
    return float('inf')