117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
|
|
"""
Audio modality for nanochat-omni (W1).

A frozen Whisper encoder produces soft tokens; Projector maps them into
nanochat's residual stream (n_embd) so they can be prepended to text token
embeddings, LLaVA-style. Output remains text-only.

Weights:
- ModelScope first when an ms_id is given or WHISPER_MS_ID is set (e.g.
  iic/Whisper-small, iic/Whisper-large-v3) — the preferred path on CN boxes
  (ailab/zy/etc).
- HuggingFace fallback (model id overridable via WHISPER_HF_ID; honors
  HF_ENDPOINT for hf-mirror).

The encoder is held frozen; only Projector is trained.
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch.nn.functional as F
|
||
|
|
|
||
|
|
from nanochat.gpt import Linear
|
||
|
|
|
||
|
|
|
||
|
|
def _load_whisper_via_modelscope(ms_id):
    """Download a Whisper checkpoint from ModelScope and load it with transformers.

    Args:
        ms_id: ModelScope model id (e.g. "iic/Whisper-small").

    Returns:
        (feature_extractor, encoder): the ``WhisperFeatureExtractor`` and the
        ``.encoder`` submodule of the loaded ``WhisperModel`` (decoder dropped).
    """
    # modelscope is imported lazily: it is only required on this code path.
    from modelscope import snapshot_download

    checkpoint_dir = snapshot_download(ms_id)

    from transformers import WhisperModel, WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor.from_pretrained(checkpoint_dir)
    whisper = WhisperModel.from_pretrained(checkpoint_dir)
    return feature_extractor, whisper.encoder
|
||
|
|
|
||
|
|
|
||
|
|
def _load_whisper_via_hf(hf_id):
    """Load a Whisper checkpoint through transformers' ``from_pretrained``.

    Args:
        hf_id: HuggingFace model id (or anything ``from_pretrained`` accepts).

    Returns:
        (feature_extractor, encoder): the ``WhisperFeatureExtractor`` and the
        ``.encoder`` submodule of the loaded ``WhisperModel`` (decoder dropped).
    """
    # transformers is imported lazily, only when this loader actually runs.
    from transformers import WhisperModel, WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor.from_pretrained(hf_id)
    whisper = WhisperModel.from_pretrained(hf_id)
    return feature_extractor, whisper.encoder
|
||
|
|
|
||
|
|
|
||
|
|
def load_whisper(hf_id="openai/whisper-base", ms_id=None):
    """Load (feature_extractor, encoder), trying ModelScope first, then HuggingFace.

    Resolution order:
      1. ModelScope, when ``ms_id`` is given or ``WHISPER_MS_ID`` is set
         (an explicit ``ms_id`` takes precedence over the environment).
      2. HuggingFace, using ``WHISPER_HF_ID`` when it is set and non-empty,
         otherwise ``hf_id``.

    Args:
        hf_id: HuggingFace model id used when WHISPER_HF_ID is unset/empty.
        ms_id: Optional ModelScope model id.

    Returns:
        (feature_extractor, encoder) where encoder is the ``.encoder``
        submodule only (no decoder).

    Raises:
        RuntimeError: if every attempted source fails. The message lists each
            failure, and ``__cause__`` is the final (HuggingFace) exception.
    """
    ms_id = ms_id or os.environ.get("WHISPER_MS_ID")
    # An empty WHISPER_HF_ID is treated as unset (mirroring the ms_id handling
    # above) rather than passing "" to from_pretrained.
    hf_id = os.environ.get("WHISPER_HF_ID") or hf_id
    errors = []
    if ms_id:
        try:
            return _load_whisper_via_modelscope(ms_id)
        except Exception as e:  # best-effort: fall through to HuggingFace
            errors.append(f"modelscope({ms_id}): {e}")
    try:
        return _load_whisper_via_hf(hf_id)
    except Exception as e:
        errors.append(f"hf({hf_id}): {e}")
        # Chain the last failure so its traceback is not lost.
        raise RuntimeError("Failed to load Whisper encoder. Tried: " + " | ".join(errors)) from e
|
||
|
|
|
||
|
|
|
||
|
|
class WhisperEncoder(nn.Module):
    """Frozen Whisper encoder.

    ``forward`` takes log-mel ``input_features`` (B, n_mels, T_mel) and returns
    the encoder's ``last_hidden_state`` (B, T_enc, d_model). Every encoder
    parameter has ``requires_grad=False``, and the encoder stays in eval mode
    even when a parent module calls ``.train()``.
    """

    def __init__(self, hf_id="openai/whisper-base", ms_id=None, device=None, dtype=None):
        super().__init__()
        extractor, encoder = load_whisper(hf_id=hf_id, ms_id=ms_id)
        self.feature_extractor = extractor
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()
        self._d_model = encoder.config.d_model
        self.sampling_rate = extractor.sampling_rate
        if device is not None or dtype is not None:
            self.encoder.to(device=device, dtype=dtype)

    def train(self, mode=True):
        """Set this module's training flag but keep the frozen encoder in eval.

        ``nn.Module.train`` recurses into child modules, so without this
        override a parent's ``model.train()`` would switch the encoder back to
        training mode (re-enabling train-time behavior such as dropout).
        ``eval()`` routes through here as ``train(False)``.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    @property
    def d_model(self):
        """Hidden size of the encoder output (``encoder.config.d_model``)."""
        return self._d_model

    def preprocess(self, audio_arrays):
        """Convert raw audio to log-mel ``input_features``.

        Args:
            audio_arrays: list of 1D np.float32 arrays (mono, ``sampling_rate`` Hz).

        Returns:
            input_features tensor (B, n_mels, T_mel).
        """
        out = self.feature_extractor(
            audio_arrays,
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
        )
        return out.input_features

    @torch.no_grad()
    def forward(self, input_features):
        """Encode log-mel features into (B, T_enc, d_model) hidden states.

        ``input_features`` is first moved to the encoder's device/dtype, so the
        output of ``preprocess`` (whatever the feature extractor returns,
        presumably CPU float32) can be passed directly even when the encoder was
        moved/cast via the ``device``/``dtype`` constructor arguments.
        """
        param = next(self.encoder.parameters(), None)
        if param is not None:
            input_features = input_features.to(device=param.device, dtype=param.dtype)
        out = self.encoder(input_features=input_features)
        return out.last_hidden_state
|
||
|
|
|
||
|
|
|
||
|
|
class Projector(nn.Module):
    """Two-layer GELU MLP (LLaVA-style): in_dim -> hidden_dim -> out_dim.

    ``hidden_dim`` falls back to ``out_dim`` when not given (or falsy). Both
    layers are bias-free nanochat ``Linear`` modules, which (per the gpt.py
    convention) keep fp32 master weights while the forward pass runs in the
    activation dtype (typically bf16).

    Initialisation: fc1 ~ U(-sqrt(3 / in_dim), +sqrt(3 / in_dim)); fc2's
    weight starts at zero.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=None):
        super().__init__()
        width = hidden_dim if hidden_dim else out_dim
        self.fc1 = Linear(in_dim, width, bias=False)
        self.fc2 = Linear(width, out_dim, bias=False)
        bound = (3.0 / in_dim) ** 0.5
        nn.init.uniform_(self.fc1.weight, -bound, bound)
        nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        """Apply fc1 -> GELU -> fc2."""
        return self.fc2(F.gelu(self.fc1(x)))
|