Files
nanochat-omni/nanochat/common.py
T

252 lines
11 KiB
Python
Raw Normal View History

2025-10-13 06:49:24 -07:00
"""
Common utilities for nanochat.
"""
import os
import re
import logging
import urllib.request
2025-10-13 06:49:24 -07:00
import torch
import torch.distributed as dist
from filelock import FileLock
2025-10-13 06:49:24 -07:00
class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds ANSI colors to log messages.

    The level name is rendered bold in a per-level color. For INFO records,
    quantities such as "12.5 GB", "40%" or "300 docs" are made bold and
    "Shard N" mentions are colored as well.
    """

    # ANSI color codes, keyed by the standard logging level names
    COLORS = {
        'DEBUG': '\033[36m',     # Cyan
        'INFO': '\033[32m',      # Green
        'WARNING': '\033[33m',   # Yellow
        'ERROR': '\033[31m',     # Red
        'CRITICAL': '\033[35m',  # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'

    def format(self, record):
        """Return the colorized, formatted text for ``record``.

        ``record.levelname`` is replaced by its colored form only while the
        base formatter runs and is restored afterwards, so the record is
        left unchanged for any other formatter and for repeated calls.
        """
        levelname = record.levelname
        if levelname in self.COLORS:
            # Temporarily add color to the level name
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        try:
            message = super().format(record)
        finally:
            # Undo the mutation: the same record object may be formatted again.
            record.levelname = levelname
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers followed by a unit (GB, MB, %, docs)
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            # Highlight shard references
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message
def setup_default_logging():
    """Configure root logging at INFO level with a colored stream handler."""
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(ColoredFormatter(log_format))
    logging.basicConfig(handlers=[stream_handler], level=logging.INFO)
# Configure logging as a side effect of importing this module.
setup_default_logging()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def get_base_dir():
    """Return the nanochat base directory, creating it if it does not exist.

    A non-empty NANOCHAT_BASE_DIR environment variable takes precedence;
    otherwise the directory is ~/.cache/nanochat, so that intermediates are
    co-located with other cached data.
    """
    override = os.environ.get("NANOCHAT_BASE_DIR")
    if override:
        base_dir = override
    else:
        base_dir = os.path.join(os.path.expanduser("~"), ".cache", "nanochat")
    os.makedirs(base_dir, exist_ok=True)
    return base_dir
def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.

    The content is first written to a temporary sibling file and then moved
    into place with os.replace, so the final path never holds a partially
    written download (e.g. after an interrupted write).

    Args:
        url: URL to fetch.
        filename: Name of the file, relative to the base directory.
        postprocess_fn: Optional callable invoked with the final file path
            after a fresh download (it is not called when the file already
            exists).

    Returns:
        The path of the downloaded (or already present) file.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"
    tmp_path = file_path + ".tmp"

    # Fast path: no need to take the lock if the file is already there
    if os.path.exists(file_path):
        return file_path

    with FileLock(lock_path):
        # Only a single rank can acquire this lock
        # All other ranks block until it is released

        # Recheck after acquiring lock
        if os.path.exists(file_path):
            return file_path

        # Download the content as bytes
        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read() # bytes

        # Write to a temporary file, then move it into place in one step so
        # that file_path only ever appears once it is complete.
        with open(tmp_path, 'wb') as f:
            f.write(content)
        os.replace(tmp_path, file_path)
        print(f"Downloaded to {file_path}")

        # Run the postprocess function if provided
        # NOTE(review): file_path is already visible at this point, so a
        # process taking the unlocked fast path may return before the
        # postprocessing has finished.
        if postprocess_fn is not None:
            postprocess_fn(file_path)

    return file_path
2025-10-13 06:49:24 -07:00
def print0(s="",**kwargs):
    """Print ``s`` only on the process whose RANK env var is 0 (or unset).

    Extra keyword arguments are forwarded to the builtin print.
    """
    rank = int(os.environ.get('RANK', 0))
    if rank != 0:
        return
    print(s, **kwargs)
def print_banner():
    """Print the nanochat ASCII-art banner through print0."""
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    # NOTE(review): the leading spaces of the art lines look like they were
    # stripped at some point; confirm the alignment against the generator output.
    banner = """
█████ █████
░░███ ░░███
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
"""
    print0(banner)
def is_ddp_requested() -> bool:
    """
    Report whether the torchrun environment variables (RANK, LOCAL_RANK,
    WORLD_SIZE) are all present, even before any initialization.
    Used to decide whether we *should* initialize a process group.
    """
    for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE"):
        if key not in os.environ:
            return False
    return True
def is_ddp_initialized() -> bool:
    """
    Report whether torch.distributed is available and its process group
    has been initialized.
    Used at cleanup to avoid destroying a non-existent process group.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()
2025-10-13 06:49:24 -07:00
def get_dist_info():
    """Return (ddp_requested, rank, local_rank, world_size) from the environment.

    Without the torchrun variables the single-process defaults
    (False, 0, 0, 1) are returned.
    """
    if not is_ddp_requested():
        return False, 0, 0, 1
    # We rely on torchrun's env to decide if we SHOULD init.
    # (Initialization itself happens in compute init.)
    env_keys = ('RANK', 'LOCAL_RANK', 'WORLD_SIZE')
    assert all(key in os.environ for key in env_keys)
    rank, local_rank, world_size = (int(os.environ[key]) for key in env_keys)
    return True, rank, local_rank, world_size
def autodetect_device_type():
    """Pick the device type: CUDA if available, else MPS, else CPU.

    The choice is announced through print0 and returned as a string.
    """
    if torch.cuda.is_available():
        detected = "cuda"
    else:
        detected = "mps" if torch.backends.mps.is_available() else "cpu"
    print0(f"Autodetected device type: {detected}")
    return detected
def compute_init(device_type="cuda"): # cuda|cpu|mps
    """Basic initialization that we keep doing over and over, so make common.

    Validates the device type, seeds the global RNGs, enables tf32 matmuls
    on CUDA and, when DDP was requested and the device is CUDA, initializes
    the NCCL process group.

    Returns:
        (ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device)
    """
    assert device_type in ("cuda", "mps", "cpu"), "Invalid device type atm"
    on_cuda = device_type == "cuda"
    if on_cuda:
        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    if device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility
    # Note that we set the global seeds here, but most of the code uses explicit rng objects.
    # The only place where global rng might be used is nn.Module initialization of the model weights.
    # Full determinism (torch.use_deterministic_algorithms) is skipped for now,
    # possibly investigate the slowdown later.
    seed = 42
    torch.manual_seed(seed)
    if on_cuda:
        torch.cuda.manual_seed(seed)
        # Precision: use tf32 instead of fp32 for matmuls
        torch.backends.cuda.matmul.fp32_precision = "tf32"

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp_requested and on_cuda:
        device = torch.device("cuda", ddp_local_rank)
        # make "cuda" default to this device before creating the process group
        torch.cuda.set_device(device)
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device(device_type) # mps|cpu

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return ddp_requested, ddp_rank, ddp_local_rank, ddp_world_size, device
2025-10-13 06:49:24 -07:00
def compute_cleanup():
    """Companion to compute_init: tear down the process group, if any, before script exit."""
    if not is_ddp_initialized():
        return
    dist.destroy_process_group()
class DummyWandb:
    """No-op stand-in for when we do not want to use wandb but keep the same signatures."""

    def __init__(self):
        """Nothing to set up."""

    def log(self, *args, **kwargs):
        """Accept any arguments and discard them."""
        return None

    def finish(self):
        """Do nothing."""
        return None
# hardcoded BF16 peak flops for NVIDIA A100, H100, H200, B200, L40S GPU and AMD MI250X, MI300X, MI325X, MI355X and Intel PVC
# inspired by torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py
def get_peak_flops(device_name: str) -> float:
    """Return the hardcoded BF16 peak FLOPS for the accelerator named ``device_name``.

    Matching is by substring of the device name. Unknown devices log a
    warning and fall back to the A100 figure.

    Args:
        device_name: Device name string, e.g. "NVIDIA H100 80GB HBM3".

    Returns:
        Peak BF16 FLOPS as a float.
    """
    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    elif "H100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h100/
        # NOTE: Specifications are one-half lower without sparsity.
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:  # for H100 SXM and other variants
            return 989e12
    elif "H200" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h200/
        return 989e12
    elif "B200" in device_name:
        # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
        return 2.25e15
    elif "MI355X" in device_name:
        # MI355X data from https://www.amd.com/en/products/accelerators/instinct/mi350/mi355x.html
        return 2500e12
    elif "MI300X" in device_name or "MI325X" in device_name:
        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
        return 1300e12
    elif "MI250X" in device_name:
        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Also known as Ponte Vecchio (PVC).
        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
        # Dot Product Accumulate Systolic (DPAS):
        # - Freq: 1300MHz
        # - #ops: 512
        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        # float() keeps the return type consistent with the other branches
        return float(512 * max_comp_units * 1300 * 10**6)
    elif "l40s" in device_name.lower():
        # Compare case-insensitively so an upper-case "L40S" in the name
        # matches too; a plain lowercase check would miss it.
        # data from: "https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413"
        return 362e12
    else:  # for other GPU types, assume A100
        logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
        return 312e12