"""
New and upgraded chat mode because a lot of the code has changed since the last one.

Intended to be run single GPU only atm:
python -m scripts.chat_cli
"""
import argparse
import torch
2025-10-20 10:15:17 -07:00
from nanochat . common import compute_init , autodetect_device_type
2025-10-13 06:49:24 -07:00
from nanochat . engine import Engine
from nanochat . checkpoint_manager import load_model
# Command-line interface.
# Option strings must begin with '-' (no surrounding whitespace), otherwise
# argparse rejects them when the argument is registered.
parser = argparse.ArgumentParser(description='Chat with the model')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|rl")
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
# An empty prompt (the default) selects interactive mode further down the script.
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
# The empty-string default is the "not specified" sentinel tested just below.
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
args = parser.parse_args()

# Init the model and tokenizer
# Autodetect the device only when the user did not pass --device-type.
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine.
# The names are passed to encode_special exactly as written, so they must not
# carry any padding whitespace.
bos = tokenizer.get_bos_token_id()
user_start, user_end = tokenizer.encode_special("<|user_start|>"), tokenizer.encode_special("<|user_end|>")
assistant_start, assistant_end = tokenizer.encode_special("<|assistant_start|>"), tokenizer.encode_special("<|assistant_end|>")

# Create Engine for efficient generation
engine = Engine(model, tokenizer)

# Usage banner
print("\nNanoChat Interactive Mode")
print("-" * 50)
print("Type 'quit' or 'exit' to end the conversation")
print("Type 'clear' to start a new conversation")
print("-" * 50)

# Running token history of the whole conversation; always starts with BOS.
conversation_tokens = [bos]
# Main chat loop. Each turn appends user_start + encoded user text + user_end
# + assistant_start to the running conversation, streams the generated reply to
# stdout token by token, then stores the reply (terminated by assistant_end) so
# the next turn sees the full history.
while True:

    if args.prompt:
        # Get the prompt from the launch command
        user_input = args.prompt
    else:
        # Get the prompt interactively from the console
        try:
            user_input = input("\nUser: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly
            print("\nGoodbye!")
            break

    # Handle special commands (case-insensitive)
    command = user_input.lower()
    if command in ('quit', 'exit'):
        print("Goodbye!")
        break

    if command == 'clear':
        conversation_tokens = [bos]
        print("Conversation cleared.")
        if args.prompt:
            # Prompt mode re-reads the same launch prompt on every iteration,
            # so `continue` here would loop forever; exit instead.
            break
        continue

    if not user_input:
        continue

    # Add User message to the conversation
    conversation_tokens.append(user_start)
    conversation_tokens.extend(tokenizer.encode(user_input))
    conversation_tokens.append(user_end)

    # Kick off the assistant
    conversation_tokens.append(assistant_start)
    generate_kwargs = {
        "num_samples": 1,
        "max_tokens": 256,
        "temperature": args.temperature,
        "top_k": args.top_k,
    }
    response_tokens = []
    print("\nAssistant: ", end="", flush=True)
    for token_column, _token_masks in engine.generate(conversation_tokens, **generate_kwargs):
        token = token_column[0]  # pop the batch dimension (num_samples=1)
        response_tokens.append(token)
        # Stream each token as soon as it is produced
        print(tokenizer.decode([token]), end="", flush=True)
    print()

    # we have to ensure that the assistant end token is the last token
    # so even if generation ends due to max tokens (or yields nothing at all),
    # we have to append it to the end
    if not response_tokens or response_tokens[-1] != assistant_end:
        response_tokens.append(assistant_end)
    conversation_tokens.extend(response_tokens)

    # In the prompt mode, we only want a single response and exit
    if args.prompt:
        break