2025-10-13 06:49:24 -07:00
#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
2025-10-15 19:12:19 +00:00
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.
Launch examples:
- single available GPU (default)
python -m scripts.chat_web
- 4 GPUs
python -m scripts.chat_web --num-gpus 4
To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
Endpoints:
GET / - Chat UI
POST /chat/completions - Chat API (streaming only)
GET /health - Health check with worker pool status
GET /stats - Worker pool statistics and GPU utilization
2025-10-15 19:42:54 +00:00
Abuse Prevention:
- Maximum 500 messages per request
- Maximum 8000 characters per message
- Maximum 32000 characters total conversation length
- Temperature clamped to 0.0-2.0
- Top-k clamped to 1-200
- Max tokens clamped to 1-4096
2025-10-13 06:49:24 -07:00
"""
import argparse
import json
import os
import torch
2025-10-15 19:12:19 +00:00
import asyncio
2025-10-15 19:51:06 +00:00
import logging
2025-10-16 01:28:37 +00:00
import random
2025-10-13 06:49:24 -07:00
from contextlib import asynccontextmanager
2025-10-15 19:42:54 +00:00
from fastapi import FastAPI , HTTPException
2025-10-13 06:49:24 -07:00
from fastapi . middleware . cors import CORSMiddleware
from fastapi . responses import StreamingResponse , HTMLResponse , FileResponse
from pydantic import BaseModel
from typing import List , Optional , AsyncGenerator
2025-10-15 19:12:19 +00:00
from dataclasses import dataclass
2025-10-20 10:15:17 -07:00
from contextlib import nullcontext
from nanochat . common import compute_init , autodetect_device_type
2025-10-13 06:49:24 -07:00
from nanochat . checkpoint_manager import load_model
from nanochat . engine import Engine
2025-10-15 19:42:54 +00:00
# Abuse prevention limits, checked per request by validate_chat_request().
# Size caps (lengths are measured in characters, not tokens):
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
# Inclusive lower/upper bounds for the sampling parameters a client may send:
MIN_TEMPERATURE, MAX_TEMPERATURE = 0.0, 2.0
MIN_TOP_K, MAX_TOP_K = 1, 200
MIN_MAX_TOKENS, MAX_MAX_TOKENS = 1, 4096
2025-10-13 06:49:24 -07:00
# Command-line interface. Arguments are parsed at import time, so `args` is
# available to every definition further down in this module.
parser = argparse.ArgumentParser(description='NanoChat Web Server')

# Worker pool size and which checkpoint to serve.
parser.add_argument(
    '-n', '--num-gpus', type=int, default=1,
    help='Number of GPUs to use (default: 1)',
)
parser.add_argument(
    '-i', '--source', type=str, default="sft",
    help="Source of the model: sft|mid|rl",
)

# Sampling defaults, used when a request does not supply its own values.
parser.add_argument(
    '-t', '--temperature', type=float, default=0.8,
    help='Default temperature for generation',
)
parser.add_argument(
    '-k', '--top-k', type=int, default=50,
    help='Default top-k sampling parameter',
)
parser.add_argument(
    '-m', '--max-tokens', type=int, default=512,
    help='Default max tokens for generation',
)

# Checkpoint selection (None is passed through to load_model unchanged).
parser.add_argument(
    '-g', '--model-tag', type=str, default=None,
    help='Model tag to load',
)
parser.add_argument(
    '-s', '--step', type=int, default=None,
    help='Step to load',
)

# Network port, numeric precision and device selection.
parser.add_argument(
    '-p', '--port', type=int, default=8000,
    help='Port to run the server on',
)
parser.add_argument(
    '-d', '--dtype', type=str, default='bfloat16',
    choices=['float32', 'bfloat16'],
)
parser.add_argument(
    '--device-type', type=str, default='',
    choices=['cuda', 'cpu', 'mps'],
    help='Device type for evaluation: cuda|cpu|mps. empty => autodetect',
)
parser.add_argument(
    '--host', type=str, default='0.0.0.0',
    help='Host to bind the server to',
)

args = parser.parse_args()
2025-10-15 19:51:06 +00:00
# Conversation traffic is logged at INFO level as "<timestamp> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
2025-10-20 10:15:17 -07:00
# An empty --device-type (the default) means "pick the device automatically".
if args.device_type == "":
    device_type = autodetect_device_type()
else:
    device_type = args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
# --dtype only admits 'float32' and 'bfloat16'; anything but 'float32' maps to bfloat16.
if args.dtype == 'float32':
    ptdtype = torch.float32
else:
    ptdtype = torch.bfloat16
2025-10-15 19:12:19 +00:00
@dataclass
class Worker:
    """A worker with a model loaded on a specific GPU."""
    # Index of this worker in the pool; on CUDA it is also the device ordinal.
    gpu_id: int
    # Device the model replica was loaded onto.
    device: torch.device
    # Generation engine wrapping this worker's model and tokenizer.
    engine: Engine
    # Tokenizer returned by load_model alongside the model.
    tokenizer: object
    # Context manager entered around generation. WorkerPool.initialize stores a
    # torch.amp.autocast on CUDA and a contextlib.nullcontext on other devices,
    # so the field is typed loosely rather than as torch.amp.autocast.
    autocast_ctx: object
class WorkerPool:
    """Pool of workers, each with a model replica on a different GPU.

    Idle workers wait in an asyncio.Queue. A request takes one with
    acquire_worker() and must hand it back with release_worker() when done.
    """

    def __init__(self, num_gpus: Optional[int] = None):
        """Record how many workers to create; no model is loaded here.

        Args:
            num_gpus: Number of workers. None means every CUDA device reported
                by torch when device_type is "cuda", otherwise one worker.
        """
        if num_gpus is None:
            if device_type == "cuda":
                num_gpus = torch.cuda.device_count()
            else:
                num_gpus = 1  # e.g. cpu|mps
        self.num_gpus = num_gpus
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()

    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
        """Load model on each GPU.

        Args:
            source: Model source forwarded to load_model.
            model_tag: Optional model tag forwarded to load_model.
            step: Optional checkpoint step forwarded to load_model.

        Raises:
            ValueError: If fewer than one worker is requested, or more than one
                worker is requested on a non-CUDA device.
        """
        # Fail fast: with zero workers nothing is ever queued, so every request
        # would block forever in acquire_worker().
        if self.num_gpus < 1:
            raise ValueError(f"num_gpus must be at least 1, got {self.num_gpus}")
        # Explicit raise instead of assert so the check survives `python -O`.
        if self.num_gpus > 1 and device_type != "cuda":
            raise ValueError("Only CUDA supports multiple workers/GPUs. cpu|mps does not.")

        print(f"Initializing worker pool with {self.num_gpus} GPUs...")

        for gpu_id in range(self.num_gpus):
            if device_type == "cuda":
                device = torch.device(f"cuda:{gpu_id}")
                print(f"Loading model on GPU {gpu_id}...")
            else:
                device = torch.device(device_type)  # e.g. cpu|mps
                print(f"Loading model on {device_type}...")

            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
            engine = Engine(model, tokenizer)
            # Autocast is only enabled on CUDA; other devices get a no-op context.
            if device_type == "cuda":
                autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
            else:
                autocast_ctx = nullcontext()

            worker = Worker(
                gpu_id=gpu_id,
                device=device,
                engine=engine,
                tokenizer=tokenizer,
                autocast_ctx=autocast_ctx,
            )
            self.workers.append(worker)
            # Every freshly loaded worker starts out idle.
            await self.available_workers.put(worker)

        print(f"All {self.num_gpus} workers initialized!")

    async def acquire_worker(self) -> Worker:
        """Get an available worker from the pool, waiting until one is idle."""
        return await self.available_workers.get()

    async def release_worker(self, worker: Worker):
        """Return a worker to the pool."""
        await self.available_workers.put(worker)
2025-10-13 06:49:24 -07:00
# One turn of the conversation as sent by the client.
class ChatMessage(BaseModel):
    # Speaker of the turn; validate_chat_request accepts "user" and "assistant".
    role: str
    # Text of the turn.
    content: str
# Request body of POST /chat/completions.
class ChatRequest(BaseModel):
    # Conversation so far, oldest turn first.
    messages: List[ChatMessage]
    # Optional sampling overrides; None falls back to the command-line defaults.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_k: Optional[int] = None
2025-10-15 19:42:54 +00:00
def validate_chat_request(request: ChatRequest):
    """Validate chat request to prevent abuse.

    Checks run in a fixed order and stop at the first violation: message
    count, per-message and total content length, message roles, then the
    optional sampling parameters.

    Args:
        request: The parsed request body.

    Raises:
        HTTPException: With status 400 and a descriptive detail for the first
            limit the request violates.
    """
    messages = request.messages

    # Check number of messages
    if not messages:
        raise HTTPException(status_code=400, detail="At least one message is required")
    if len(messages) > MAX_MESSAGES_PER_REQUEST:
        raise HTTPException(
            status_code=400,
            detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
        )

    # Check individual message lengths and total conversation length
    total_length = 0
    for i, message in enumerate(messages):
        if not message.content:
            raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
        msg_length = len(message.content)
        if msg_length > MAX_MESSAGE_LENGTH:
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
            )
        total_length += msg_length
    if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
        )

    # Validate role values in a separate pass so that length errors for any
    # message are reported before role errors. Only the two roles that the
    # chat endpoint knows how to tokenize are accepted; the error message
    # lists exactly those.
    for i, message in enumerate(messages):
        if message.role not in ("user", "assistant"):
            raise HTTPException(
                status_code=400,
                detail=f"Message {i} has invalid role. Must be 'user' or 'assistant'"
            )

    # Validate the optional sampling parameters against their inclusive bounds.
    # A value of None means "use the server default" and is not checked.
    bounded_params = (
        ("Temperature", request.temperature, MIN_TEMPERATURE, MAX_TEMPERATURE),
        ("top_k", request.top_k, MIN_TOP_K, MAX_TOP_K),
        ("max_tokens", request.max_tokens, MIN_MAX_TOKENS, MAX_MAX_TOKENS),
    )
    for label, value, lower, upper in bounded_params:
        if value is not None and not (lower <= value <= upper):
            raise HTTPException(
                status_code=400,
                detail=f"{label} must be between {lower} and {upper}"
            )
2025-10-13 06:49:24 -07:00
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the worker pool before the server starts handling requests.

    The pool is attached to app.state.worker_pool first and then populated,
    one model replica per requested GPU. Nothing is torn down after the yield.
    """
    print("Loading nanochat models across GPUs...")
    pool = WorkerPool(num_gpus=args.num_gpus)
    app.state.worker_pool = pool
    await pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
    print(f"Server ready at http://localhost:{args.port}")
    yield
# The application object; `lifespan` loads the worker pool on startup.
app = FastAPI(lifespan=lifespan)

# CORS is left wide open: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# permissive setup - confirm this is intended for the deployment.
cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_options)
@app.get("/")
async def root():
    """Serve the chat UI."""
    # The page is read from nanochat/ui.html, resolved against the current
    # working directory, on every request.
    page_path = os.path.join("nanochat", "ui.html")
    with open(page_path, "r", encoding='utf-8') as page_file:
        page = page_file.read()
    # Replace the API_URL to use the same origin: the static page hard-codes
    # port 8000, while this server answers API calls on whatever port it runs on.
    hardcoded_api_url = "const API_URL = `http://${window.location.hostname}:8000`;"
    same_origin_api_url = "const API_URL = '';"
    page = page.replace(hardcoded_api_url, same_origin_api_url)
    return HTMLResponse(content=page)
@app.get("/logo.svg")
async def logo():
    """Serve the NanoChat logo for favicon and header."""
    # nanochat/logo.svg is resolved against the current working directory.
    return FileResponse(
        os.path.join("nanochat", "logo.svg"),
        media_type="image/svg+xml",
    )
async def generate_stream(
    worker: Worker,
    tokens,
    temperature=None,
    max_new_tokens=None,
    top_k=None
) -> AsyncGenerator[str, None]:
    """Stream the assistant reply as server-sent-event chunks.

    Args:
        worker: Worker whose engine, tokenizer and autocast context are used.
        tokens: Prompt token ids passed to the engine.
        temperature: Sampling temperature; None uses args.temperature.
        max_new_tokens: Generation cap; None uses args.max_tokens.
        top_k: Top-k cutoff; None uses args.top_k.

    Yields:
        "data: {json}\\n\\n" strings. Text chunks carry 'token' and 'gpu' keys;
        the last chunk is {'done': True}.
    """
    # Fall back to the command-line defaults for anything the caller left out.
    if temperature is None:
        temperature = args.temperature
    if max_new_tokens is None:
        max_new_tokens = args.max_tokens
    if top_k is None:
        top_k = args.top_k

    tokenizer = worker.tokenizer
    assistant_end = tokenizer.encode_special("<|assistant_end|>")
    bos = tokenizer.get_bos_token_id()

    # Every generated token is kept and the whole list is re-decoded on each
    # step, so a multi-byte UTF-8 character (like an emoji) split across tokens
    # is only emitted once it is complete.
    generated = []
    # Decoded text as of the last emitted chunk.
    emitted_text = ""

    with worker.autocast_ctx:
        sampler = worker.engine.generate(
            tokens,
            num_samples=1,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            seed=random.randint(0, 2**31 - 1)
        )
        for token_column, _token_masks in sampler:
            token = token_column[0]  # num_samples=1, so a single entry is used

            # Stopping criteria
            if token == assistant_end or token == bos:
                break

            generated.append(token)
            decoded = tokenizer.decode(generated)

            # A trailing replacement character (U+FFFD) marks an incomplete
            # UTF-8 sequence: hold the text back until more tokens arrive.
            if decoded.endswith('�'):
                continue

            # Emit only what was added since the last emitted chunk.
            fresh_text = decoded[len(emitted_text):]
            if fresh_text:
                payload = {'token': fresh_text, 'gpu': worker.gpu_id}
                yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
                emitted_text = decoded

    yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
    """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
    # Flow: validate -> log the conversation -> borrow a worker -> tokenize the
    # conversation -> stream chunks to the client -> give the worker back.
    # validate_chat_request raises HTTPException(400) on any violated limit.

    # Basic validation to prevent abuse
    validate_chat_request(request)

    # Log incoming conversation to console
    logger.info("=" * 20)
    for message in request.messages:
        logger.info(f"[{message.role.upper()}]: {message.content}")
    logger.info("-" * 20)

    # Acquire a worker from the pool (will wait if all are busy)
    worker_pool = app.state.worker_pool
    worker = await worker_pool.acquire_worker()

    try:
        # Build conversation tokens: <bos>, each turn wrapped in its role's
        # start/end markers, then an open assistant turn for the model to fill.
        tokenizer = worker.tokenizer
        bos = tokenizer.get_bos_token_id()
        user_start = tokenizer.encode_special("<|user_start|>")
        user_end = tokenizer.encode_special("<|user_end|>")
        assistant_start = tokenizer.encode_special("<|assistant_start|>")
        assistant_end = tokenizer.encode_special("<|assistant_end|>")

        conversation_tokens = [bos]
        for message in request.messages:
            if message.role == "user":
                conversation_tokens.append(user_start)
                conversation_tokens.extend(tokenizer.encode(message.content))
                conversation_tokens.append(user_end)
            elif message.role == "assistant":
                conversation_tokens.append(assistant_start)
                conversation_tokens.extend(tokenizer.encode(message.content))
                conversation_tokens.append(assistant_end)
        conversation_tokens.append(assistant_start)

        # Streaming response with worker release after completion
        sse_prefix = "data: "
        response_tokens = []

        async def stream_and_release():
            try:
                async for chunk in generate_stream(
                    worker,
                    conversation_tokens,
                    temperature=request.temperature,
                    max_new_tokens=request.max_tokens,
                    top_k=request.top_k
                ):
                    # Accumulate response for logging. Only the leading SSE
                    # prefix is stripped; a blanket replace would also delete
                    # "data: " from inside the generated text.
                    payload = chunk[len(sse_prefix):] if chunk.startswith(sse_prefix) else chunk
                    chunk_data = json.loads(payload.strip())
                    if "token" in chunk_data:
                        response_tokens.append(chunk_data["token"])
                    yield chunk
            finally:
                # Log the assistant response to console
                full_response = "".join(response_tokens)
                logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
                logger.info("=" * 20)
                # Release worker back to pool after streaming is done
                await worker_pool.release_worker(worker)

        # NOTE(review): the worker is released by stream_and_release's finally
        # clause, which presumably only runs once the response body starts being
        # iterated - confirm the worker cannot leak if the client disconnects
        # before streaming begins.
        return StreamingResponse(
            stream_and_release(),
            media_type="text/event-stream"
        )
    except Exception:
        # Make sure to release worker even on error, then re-raise unchanged.
        await worker_pool.release_worker(worker)
        raise
2025-10-13 06:49:24 -07:00
@app.get("/health")
async def health():
    """Health check endpoint."""
    pool = getattr(app.state, 'worker_pool', None)
    # Without a pool on app.state, report "not ready" with zero counts.
    if not pool:
        return {
            "status": "ok",
            "ready": False,
            "num_gpus": 0,
            "available_workers": 0
        }
    return {
        "status": "ok",
        # Ready as soon as at least one worker has been loaded.
        "ready": len(pool.workers) > 0,
        "num_gpus": pool.num_gpus,
        "available_workers": pool.available_workers.qsize()
    }
@app.get("/stats")
async def stats():
    """Get worker pool statistics."""
    pool = app.state.worker_pool
    total = len(pool.workers)
    # Workers currently sitting idle in the queue; the rest are serving requests.
    idle = pool.available_workers.qsize()
    worker_details = []
    for w in pool.workers:
        worker_details.append({
            "gpu_id": w.gpu_id,
            "device": str(w.device)
        })
    return {
        "total_workers": total,
        "available_workers": idle,
        "busy_workers": total - idle,
        "workers": worker_details
    }
if __name__ == "__main__":
    # uvicorn is imported only when the module is run as the entry point.
    import uvicorn

    print("Starting NanoChat Web Server")
    print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
    uvicorn.run(app, host=args.host, port=args.port)