2025-10-13 06:49:24 -07:00
"""
Evaluate the Chat model.
2025-11-14 11:20:25 +01:00
All the generic code lives here, and all the evaluation-specific
2025-10-13 06:49:24 -07:00
code lives in nanochat directory and is imported from here.
Example runs:
2026-01-31 19:12:25 +00:00
python -m scripts.chat_eval -i sft -a ARC-Easy
torchrun --nproc_per_node=8 -m scripts.chat_eval -- -i sft -a ARC-Easy
2025-10-13 06:49:24 -07:00
"""
import argparse
from functools import partial
import torch
import torch . distributed as dist
2025-10-20 10:15:17 -07:00
from nanochat . common import compute_init , compute_cleanup , get_dist_info , print0 , autodetect_device_type
2025-10-13 06:49:24 -07:00
from nanochat . checkpoint_manager import load_model
from nanochat . engine import Engine
from tasks . humaneval import HumanEval
from tasks . mmlu import MMLU
from tasks . arc import ARC
from tasks . gsm8k import GSM8K
2025-10-24 14:02:48 +00:00
from tasks . spellingbee import SpellingBee
2025-10-13 06:49:24 -07:00
# -----------------------------------------------------------------------------
# Generative evaluation loop (we go one problem at a time, sample, evaluate)
def run_generative_eval(task_object, tokenizer, model, engine, num_samples, max_new_tokens, temperature, top_k, max_problems=None):
    """
    Evaluate a generative task by sampling completions one problem at a time.

    Problems are strided across ranks (rank r takes r, r + world_size, ...).
    A problem counts as passed when at least one of its sampled completions is
    accepted by `task_object.evaluate`. Pass/total counts are summed over all
    ranks before the accuracy is computed.

    Args:
        task_object: indexable task; `task_object[i]` is a conversation and
            `task_object.evaluate(conversation, completion)` scores one completion.
        tokenizer: must provide `render_for_completion` and `decode`.
        model: used here only to find the device for the all-reduce tensors.
        engine: must provide `generate_batch`.
        num_samples: number of completions sampled per problem.
        max_new_tokens: forwarded to the engine as `max_tokens`.
        temperature: forwarded to the engine.
        top_k: forwarded to the engine.
        max_problems: optional cap on how many problems are evaluated.

    Returns:
        Fraction of evaluated problems that passed, or 0.0 when no problem was
        evaluated at all (empty task or `max_problems <= 0`).
    """
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    device = model.get_device()

    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)

    # Run the evaluation
    num_passed, total = 0, 0
    for i in range(ddp_rank, num_problems, ddp_world_size):
        conversation = task_object[i]

        # Tokenize the prompt
        encoded_prompt = tokenizer.render_for_completion(conversation)
        # Get the completions
        results, _ = engine.generate_batch(
            encoded_prompt,
            num_samples=num_samples,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
        )
        # Decode the completions as text (drop the prompt prefix from each result)
        prefix_length = len(encoded_prompt)
        completions = [tokenizer.decode(result_tokens[prefix_length:]) for result_tokens in results]
        # Evaluate success criteria: passing any one sample is enough
        outcomes = [task_object.evaluate(conversation, completion) for completion in completions]
        passed = any(outcomes)

        # Keep stats
        total += 1
        num_passed += int(passed)

        # Logging (overwrite the same line in the console); total >= 1 here
        print(f"\r\033[KRank {ddp_rank} | {num_passed}/{total} ({100*num_passed/total:.2f}%)", end='', flush=True)

    # Finish the in-place progress line with a newline before final summary
    print()

    # Aggregate results across all ranks
    if ddp:
        num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
        total_tensor = torch.tensor([total], dtype=torch.long, device=device)
        dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
        num_passed = num_passed_tensor.item()
        total = total_tensor.item()

    print0("=" * 50)
    # Guard placed after the all-reduce so every rank still joins the collective
    # and all ranks take the same branch (total is now the global count).
    if total == 0:
        print0("Final: 0/0 (no problems were evaluated)")
        return 0.0
    print0(f"Final: {num_passed}/{total} ({100*num_passed/total:.2f}%)")

    # Return the accuracy
    return num_passed / total
# -----------------------------------------------------------------------------
# Categorical evaluation loop
# A lot easier because we don't have to sample. Therefore, we can actually go
# batches at a time and just check the logits for correct answer choices.
def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems=None):
    """
    Evaluate a multiple-choice task by reading the answer-letter logits directly.

    Problems are grouped into batches of `batch_size`, and batches are strided
    across ranks. Each batch is right-padded to a common length and run through
    the model once. For every problem, the logits at its last prompt token are
    restricted to the token ids of that problem's answer letters, and the
    argmax letter is taken as the prediction. Pass/total counts are summed over
    all ranks before the accuracy is computed.

    Args:
        task_object: indexable task; `task_object[i]` is a conversation with a
            'letters' entry, and `task_object.evaluate(conversation, letter)`
            scores a predicted letter.
        tokenizer: must provide `get_bos_token_id`, `render_for_completion`
            and `encode`.
        model: callable on a (B, T) tensor of token ids, returning logits that
            are indexed here as (B, T, V); also provides `get_device`.
        batch_size: number of problems per forward pass.
        max_problems: optional cap on how many problems are evaluated.

    Returns:
        Fraction of evaluated problems answered correctly, or 0.0 when no
        problem was evaluated at all (empty task or `max_problems <= 0`).
    """
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    device = model.get_device()
    bos = tokenizer.get_bos_token_id()  # use BOS as pad token is ok, these positions are ignored

    # We'll process batches of independent problems at a time because there is no sampling needed
    num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
    num_batches = -(-num_problems // batch_size)  # ceiling division

    # Run the evaluation
    letter_to_id_cache = {}  # many letters will repeat often, let's save the tokenizer some work
    num_passed, total = 0, 0
    for i in range(ddp_rank, num_batches, ddp_world_size):
        i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)

        # Prepare the batch of problems. They might all be of different length, so we pad/collate them.
        conversations = [task_object[ii] for ii in range(i0, i1)]
        prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations]  # TODO: remake the way this works
        max_length = max(len(ids) for ids in prompt_ids)
        answer_time_positions = [len(ids) - 1 for ids in prompt_ids]  # where the last token is (and the predicted answer)
        padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
        prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)

        # Get the logits for the whole batch of conversations in parallel (efficiency win here)
        with torch.no_grad():
            logits = model(prompt_ids)  # (B, T, V)

        # Focus on the available answer on just the letters corresponding to choices.
        # Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters.
        # The much harder alternative would be to just generate from the Assistant and check if it responded with the correct
        # letter (e.g. A, B, C, D), but evaluations typically make the task easier in this way.
        for idx, conversation in enumerate(conversations):
            # get the token ids of all the available letters of this problem
            letters = conversation['letters']
            letter_ids = []
            for letter in letters:
                if letter not in letter_to_id_cache:
                    encoded_letter = tokenizer.encode(letter)
                    assert len(encoded_letter) == 1, "Each letter must be a single token"
                    letter_to_id_cache[letter] = encoded_letter[0]
                letter_ids.append(letter_to_id_cache[letter])
            # focus logits just down to the answer position and the available letters of the answer
            answer_pos = answer_time_positions[idx]
            focus_logits = logits[idx, answer_pos, letter_ids]
            # get the argmax letter (the predicted answer); the index is into `letters`
            argmax_letter_id = focus_logits.argmax(dim=-1).item()
            predicted_letter = letters[argmax_letter_id]
            # evaluate the outcome
            outcome = task_object.evaluate(conversation, predicted_letter)
            num_passed += int(outcome)
            total += 1

    # Aggregate results across all ranks
    if ddp:
        num_passed_tensor = torch.tensor([num_passed], dtype=torch.long, device=device)
        total_tensor = torch.tensor([total], dtype=torch.long, device=device)
        dist.all_reduce(num_passed_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)
        num_passed = num_passed_tensor.item()
        total = total_tensor.item()

    # Guard placed after the all-reduce so every rank still joins the collective
    # and all ranks take the same branch (total is now the global count).
    if total == 0:
        print0("Final: 0/0 (no problems were evaluated)")
        return 0.0

    average = num_passed / total
    print0(f"Final: {num_passed}/{total} ({100*average:.2f}%)")
    return average
# -----------------------------------------------------------------------------
def run_chat_eval(task_name, model, tokenizer, engine,
                  batch_size=1, num_samples=1, max_new_tokens=512, temperature=0.0, top_k=50,
                  max_problems=None):
    """
    Build the task registered under `task_name` and run the matching eval loop.

    The task's `eval_type` attribute selects the loop: 'generative' goes to
    `run_generative_eval`, 'categorical' goes to `run_categorical_eval`.

    Returns:
        The accuracy returned by the selected evaluation loop.

    Raises:
        KeyError: if `task_name` is not one of the registered tasks.
        ValueError: if the task reports an unsupported `eval_type`.
    """
    # Registry of task constructors, keyed by the public task name
    task_factories = {
        'HumanEval': HumanEval,
        'MMLU': partial(MMLU, subset="all", split="test"),
        'ARC-Easy': partial(ARC, subset="ARC-Easy", split="test"),
        'ARC-Challenge': partial(ARC, subset="ARC-Challenge", split="test"),
        'GSM8K': partial(GSM8K, subset="main", split="test"),
        'SpellingBee': partial(SpellingBee, size=256, split="test"),
    }
    build_task = task_factories[task_name]
    task_object = build_task()

    # Dispatch on the kind of evaluation the task asks for
    if task_object.eval_type == 'generative':
        return run_generative_eval(
            task_object, tokenizer, model, engine,
            num_samples, max_new_tokens, temperature, top_k,
            max_problems=max_problems,
        )
    if task_object.eval_type == 'categorical':
        return run_categorical_eval(
            task_object, tokenizer, model, batch_size,
            max_problems=max_problems,
        )
    raise ValueError(f"Unsupported task evaluation type: {task_object.eval_type}")
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # ---- Command-line interface ----
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--source', type=str, required=True, help="Source of the model: sft|rl")
    parser.add_argument('-a', '--task-name', type=str, default=None, help="Task name. Default = all tasks. Use | to split multiple tasks.")
    parser.add_argument('-t', '--temperature', type=float, default=0.0)
    parser.add_argument('-m', '--max-new-tokens', type=int, default=512)
    parser.add_argument('-n', '--num-samples', type=int, default=1)
    parser.add_argument('-k', '--top-k', type=int, default=50)
    parser.add_argument('-b', '--batch-size', type=int, default=8, help='Batch size for categorical evaluation')
    parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
    parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
    parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
    parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
    args = parser.parse_args()

    # ---- Compute setup and model loading ----
    # An empty --device-type (the default) means "pick the device automatically"
    if args.device_type == "":
        device_type = autodetect_device_type()
    else:
        device_type = args.device_type
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)

    model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
    engine = Engine(model, tokenizer)

    # ---- Task selection ----
    all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
    # Accuracy of random guessing per task, used to center the ChatCORE metric
    baseline_accuracies = {
        'ARC-Easy': 0.25,  # multiple choice 1 of 4 => 25%
        'ARC-Challenge': 0.25,  # multiple choice 1 of 4 => 25%
        'MMLU': 0.25,  # multiple choice 1 of 4 => 25%
        'GSM8K': 0.0,  # open-ended => 0%
        'HumanEval': 0.0,  # open-ended => 0%
        'SpellingBee': 0.0,  # open-ended => 0%
    }
    if args.task_name is None:
        task_names = all_tasks
    else:
        task_names = args.task_name.split('|')

    # ---- Run every requested task, one after another ----
    eval_kwargs = dict(
        batch_size=args.batch_size,
        num_samples=args.num_samples,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        max_problems=args.max_problems,
    )
    results = {}
    for task_name in task_names:
        acc = run_chat_eval(task_name, model, tokenizer, engine, **eval_kwargs)
        results[task_name] = acc
        print0(f"{task_name} accuracy: {100 * acc:.2f}%")

    # ---- Log to report ----
    from nanochat.report import get_report

    # ChatCORE (similar to CORE) is the mean of baseline-centered accuracies, so it
    # ranges from 0 (at random baseline) to 1 (peak performance). It is only
    # reported when every task in all_tasks has a result.
    chatcore_metric_dict = {}
    all_tasks_were_evaluated = all(name in results for name in all_tasks)
    if all_tasks_were_evaluated:
        centered_mean = 0
        for task_name, acc in results.items():
            baseline_acc = baseline_accuracies.get(task_name, 0.0)
            centered_mean += (acc - baseline_acc) / (1.0 - baseline_acc)
        chatcore_metric = centered_mean / len(results)
        chatcore_metric_dict = {"ChatCORE metric": chatcore_metric}

    report_section = "Chat evaluation " + args.source
    report_data = [
        vars(args),  # CLI args
        results,
        chatcore_metric_dict,
    ]
    get_report().log(section=report_section, data=report_data)

    compute_cleanup()