session-tools/contrib/apricot-stt-refactor/stt_service.py.postWHISPERHTTP
Natalie 7138338d31 feat(@scripts): add whisper-http backend config and stt service refactor
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-05-17 22:11:19 -07:00

313 lines
12 KiB
Text

"""STT service delegating to the model-boss whisper-http backend.
Mirrors the architecture of `tts_service.py`: this process owns no Whisper
model and no VRAM. All transcription requests are proxied to the
`whisper-http` service (which acquires a model-boss VRAM lease around its
own model lifecycle, coordinating with TTS to prevent OOM).
Public surface is preserved 1:1 with the old in-process implementation so
the existing routes (routes/stt.py) and websocket streamers don't need
changes — only the internals are different.
"""
from __future__ import annotations
import asyncio
import base64
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from chatterbox_tts_service.config import ChatterboxSettings
logger = logging.getLogger(__name__)
# Whisper model names accepted by the whisper-http backend (via its
# WhisperLoader). STTService validates requested model names against this
# list; it stays a mutable list for compatibility with existing importers.
WHISPER_MODELS = [
    "tiny",
    "tiny.en",
    "base",
    "base.en",
    "small",
    "small.en",
    "medium",
    "medium.en",
    "large-v2",
    "large-v3",
    "turbo",
]
@dataclass
class TranscriptionSegment:
    """One timed span of text produced by the transcription backend.

    The optional scores stay None when the backend does not supply them.
    """

    # Segment bounds as reported by the backend (presumably seconds from the
    # start of the clip — TODO confirm against whisper-http).
    start: float
    end: float
    # Transcribed text for this span.
    text: str
    # Optional per-segment scores; None means "not reported".
    confidence: float | None = field(default=None)
    no_speech_prob: float | None = field(default=None)
@dataclass
class TranscriptionResult:
    """Outcome of a single transcription request.

    ``segments`` defaults to a fresh empty list per instance;
    ``average_confidence`` defaults to None and ``model_used`` to "".
    """

    # Full transcript text.
    text: str
    # Detected (or requested) language and the backend's probability for it.
    language: str
    language_probability: float
    # Clip duration as reported by the backend.
    duration_seconds: float
    segments: list[TranscriptionSegment] = field(default_factory=list)
    average_confidence: float | None = field(default=None)
    model_used: str = field(default="")
class STTService:
    """Speech-to-Text service backed by the whisper-http endpoint.

    No VRAM is held by this process — model lifecycle is owned by model-boss
    via whisper-http. State tracking (`current_model`, `is_model_loaded`,
    `device`) reflects the last successful transcription or ``/health``
    probe, not local memory.
    """

    def __init__(self, settings: ChatterboxSettings) -> None:
        self.settings = settings
        # ``whisper_http_url`` is read off the settings object; when it is
        # missing or empty we fall back to localhost:10011 (noted as matching
        # the systemd unit / router.py default). Trailing slashes are stripped
        # so the joined endpoint paths below never contain "//".
        base_url = getattr(settings, "whisper_http_url", None) or "http://localhost:10011"
        self._whisper_http_url: str = str(base_url).rstrip("/")
        # Created lazily by _get_http_client(); closed by cleanup().
        self._http_client: Any | None = None
        # Last model requested via load_model/get_model or used by a
        # successful transcription.
        self._last_model: str | None = None
        # Cached health snapshot so `device` / `is_model_loaded` can return
        # meaningful values without doing a network round-trip on every read.
        self._cached_health: dict[str, Any] = {}
        # Serializes concurrent /health probes.
        self._health_lock = asyncio.Lock()

    # ─── Public properties (preserved API) ────────────────────────────────────

    @property
    def is_model_loaded(self) -> bool:
        """Whether the cached snapshot says whisper-http has a model loaded.

        False until the first successful health check / transcription, and
        again after ``unload_model()`` clears the flag.
        """
        return bool(self._cached_health.get("model_loaded"))

    @property
    def current_model(self) -> str | None:
        """Model id from the cached snapshot, else the last model requested/used."""
        return self._cached_health.get("model_id") or self._last_model

    @property
    def device(self) -> str:
        """Device reported by whisper-http; ``"remote"`` if no snapshot has one."""
        return self._cached_health.get("device") or "remote"

    @property
    def available_models(self) -> list[str]:
        """A fresh copy of the model names this service accepts."""
        return WHISPER_MODELS.copy()

    # ─── Lifecycle (backend owns the model) ───────────────────────────────────

    async def load_model(self, model_name: str = "base") -> None:
        """Validate and remember *model_name*, then refresh the health snapshot.

        This does not force whisper-http to load anything (per the original
        notes it loads lazily on the first ``/transcribe``); callers that want
        to pre-warm can push a short clip through ``transcribe_bytes()``.

        Raises:
            ValueError: if *model_name* is not in ``WHISPER_MODELS``.
        """
        self._validate_model(model_name)
        self._last_model = model_name
        await self._refresh_health()

    async def get_model(self, model_name: str = "base") -> None:
        """Validate and remember *model_name*; no model object is returned.

        Kept for API compatibility — the model itself lives in whisper-http.

        Raises:
            ValueError: if *model_name* is not in ``WHISPER_MODELS``.
        """
        self._validate_model(model_name)
        self._last_model = model_name
        return None

    async def unload_model(self) -> None:
        """Forget the locally cached "model loaded" state; sends no request.

        whisper-http manages unloading via its own lease lifecycle; this only
        drops ``model_loaded`` / ``model_id`` from the cached snapshot.
        """
        self._cached_health.pop("model_loaded", None)
        self._cached_health.pop("model_id", None)

    def get_model_info(self) -> dict[str, Any]:
        """Return a diagnostic summary of the backend and cached state.

        ``health`` is a shallow copy so callers cannot mutate the cache.
        """
        return {
            "backend": "whisper-http",
            "url": self._whisper_http_url,
            "current_model": self.current_model,
            "device": self.device,
            "is_model_loaded": self.is_model_loaded,
            "available_models": self.available_models,
            "health": dict(self._cached_health),
        }

    # ─── Transcription (proxies to whisper-http) ──────────────────────────────

    async def transcribe(
        self,
        audio_path: Path | str,
        *,
        model: str = "base",
        language: str | None = None,
        task: Literal["transcribe", "translate"] = "transcribe",
        temperature: float = 0.0,  # accepted for API parity; not forwarded
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1.0,  # accepted, not forwarded (loader default)
        length_penalty: float = 1.0,  # accepted, not forwarded
        initial_prompt: str | None = None,
        word_timestamps: bool = False,
        vad_filter: bool = False,
        vad_parameters: dict[str, Any] | None = None,  # accepted, not forwarded
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe an audio file via whisper-http.

        The file is read in a worker thread (``asyncio.to_thread``) and handed
        to ``transcribe_bytes()``. ``temperature``, ``patience``,
        ``length_penalty`` and ``vad_parameters`` are accepted for API
        compatibility but are NOT sent to whisper-http.

        Raises:
            FileNotFoundError: if *audio_path* does not exist.
            ValueError / RuntimeError: as raised by ``transcribe_bytes()``.
        """
        path = Path(audio_path)
        if not path.exists():
            raise FileNotFoundError(f"Audio file not found: {path}")
        audio_bytes = await asyncio.to_thread(path.read_bytes)
        return await self.transcribe_bytes(
            audio_bytes,
            model=model,
            language=language,
            task=task,
            beam_size=beam_size,
            best_of=best_of,
            initial_prompt=initial_prompt,
            word_timestamps=word_timestamps,
            vad_filter=vad_filter,
            **kwargs,
        )

    async def transcribe_bytes(
        self,
        audio_bytes: bytes,
        *,
        model: str = "base",
        language: str | None = None,
        task: Literal["transcribe", "translate"] = "transcribe",
        beam_size: int = 5,
        best_of: int = 5,
        initial_prompt: str | None = None,
        word_timestamps: bool = False,
        vad_filter: bool = False,
        _retry: bool = False,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe raw audio bytes via whisper-http's ``/transcribe`` endpoint.

        The audio is base64-encoded into a JSON payload; ``language`` and
        ``initial_prompt`` are only included when truthy. On success the
        cached snapshot is updated to mark the returned model as loaded.

        Args:
            audio_bytes: Audio file contents; must be non-empty.
            model: One of ``WHISPER_MODELS``.
            _retry: Internal flag kept for backward compatibility; when True
                the request is not retried on connection errors.
            **kwargs: Accepted for API compatibility and ignored.

        Raises:
            ValueError: unknown *model* or empty *audio_bytes*.
            RuntimeError: connection failure (after one retry), a non-200
                status, or a response body that is not a JSON object.
        """
        self._validate_model(model)
        if not audio_bytes:
            raise ValueError("transcribe_bytes called with empty audio")

        payload: dict[str, Any] = {
            "audio": base64.b64encode(audio_bytes).decode("ascii"),
            "model": model,
            "task": task,
            "beam_size": beam_size,
            "best_of": best_of,
            "word_timestamps": word_timestamps,
            "vad_filter": vad_filter,
        }
        if language:
            payload["language"] = language
        if initial_prompt:
            payload["initial_prompt"] = initial_prompt

        response = await self._post_transcribe(payload, allow_retry=not _retry)
        if response.status_code != 200:
            raise RuntimeError(
                f"whisper-http returned {response.status_code}: {response.text[:500]}"
            )
        try:
            data = response.json()
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            raise RuntimeError(f"whisper-http returned invalid JSON: {exc}") from exc

        result = self._parse_response(data, model=model, language=language)
        self._last_model = result.model_used or model
        # Opportunistically update the snapshot from the successful response so
        # is_model_loaded / current_model reflect reality without an extra
        # round-trip (device is only refreshed by the /health probe).
        self._cached_health["model_loaded"] = True
        self._cached_health["model_id"] = self._last_model
        return result

    async def cleanup(self) -> None:
        """Close the shared HTTP client, if one was created.

        The attribute is cleared before awaiting ``aclose()`` so a later
        ``_get_http_client()`` builds a fresh client rather than handing out
        the one being closed.
        """
        client, self._http_client = self._http_client, None
        if client is not None:
            await client.aclose()

    # ─── Private ──────────────────────────────────────────────────────────────

    @staticmethod
    def _validate_model(model_name: str) -> None:
        """Raise ValueError unless *model_name* is one of ``WHISPER_MODELS``."""
        if model_name not in WHISPER_MODELS:
            raise ValueError(
                f"Invalid model '{model_name}'. Available: {', '.join(WHISPER_MODELS)}"
            )

    def _get_http_client(self) -> Any:
        """Return the shared ``httpx.AsyncClient``, creating it on first use.

        The client uses a 180 s overall timeout with a 10 s connect timeout.
        """
        import httpx

        if self._http_client is None:
            self._http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(180.0, connect=10.0),
            )
        return self._http_client

    async def _post_transcribe(self, payload: dict[str, Any], *, allow_retry: bool) -> Any:
        """POST *payload* to ``/transcribe`` and return the httpx response.

        ``ConnectError`` / ``RemoteProtocolError`` trigger one retry when
        *allow_retry* is True; if the final attempt still fails they are
        wrapped in RuntimeError. Other httpx errors (e.g. timeouts) propagate
        unchanged.
        """
        import httpx

        url = f"{self._whisper_http_url}/transcribe"
        attempts = 2 if allow_retry else 1
        last_exc: Exception | None = None
        for attempt in range(attempts):
            # Re-fetch each attempt in case cleanup() replaced the client.
            client = self._get_http_client()
            try:
                return await client.post(url, json=payload)
            except (httpx.ConnectError, httpx.RemoteProtocolError) as exc:
                last_exc = exc
                if attempt + 1 < attempts:
                    logger.warning("whisper-http connection error, retrying once: %s", exc)
        raise RuntimeError(f"whisper-http request failed: {last_exc}") from last_exc

    @staticmethod
    def _parse_response(
        data: Any, *, model: str, language: str | None
    ) -> TranscriptionResult:
        """Build a TranscriptionResult from whisper-http's decoded JSON body.

        Fields that are missing *or null* fall back to defaults: text ``""``,
        language = the requested language or ``"unknown"``,
        language_probability 1.0, duration 0.0, model_used = requested model,
        segments ``[]``. ``average_confidence`` is the mean of the segment
        confidences that are present, or None when none are.

        Raises:
            RuntimeError: if *data* is not a JSON object.
            KeyError: if a segment lacks ``start``, ``end`` or ``text``.
        """
        if not isinstance(data, dict):
            raise RuntimeError(
                f"whisper-http returned a non-object JSON body: {type(data).__name__}"
            )

        def opt_float(value: Any) -> float | None:
            # Preserve "not reported" (None) instead of crashing in float().
            return None if value is None else float(value)

        segments = [
            TranscriptionSegment(
                start=float(s["start"]),
                end=float(s["end"]),
                text=str(s["text"]),
                confidence=opt_float(s.get("confidence")),
                no_speech_prob=opt_float(s.get("no_speech_prob")),
            )
            for s in data.get("segments") or []
        ]
        confidences = [s.confidence for s in segments if s.confidence is not None]
        avg_conf = sum(confidences) / len(confidences) if confidences else None

        lang_prob = data.get("language_probability")
        duration = data.get("duration_seconds")
        return TranscriptionResult(
            text=str(data.get("text") or ""),
            language=str(data.get("language") or language or "unknown"),
            language_probability=1.0 if lang_prob is None else float(lang_prob),
            duration_seconds=0.0 if duration is None else float(duration),
            segments=segments,
            average_confidence=avg_conf,
            model_used=str(data.get("model_used") or model),
        )

    async def _refresh_health(self) -> None:
        """Fetch ``/health`` from whisper-http and cache it if it is a JSON object.

        Best-effort: connection errors, timeouts, non-200 statuses and
        malformed bodies are logged at DEBUG and the previous snapshot is kept.
        """
        async with self._health_lock:
            client = self._get_http_client()
            try:
                response = await client.get(f"{self._whisper_http_url}/health", timeout=5.0)
                if response.status_code != 200:
                    logger.debug(
                        "whisper-http health probe returned status %s", response.status_code
                    )
                    return
                health = response.json()
            except Exception as exc:  # noqa: BLE001 — deliberate best-effort probe
                logger.debug("whisper-http health probe failed: %s", exc)
                return
            if isinstance(health, dict):
                self._cached_health = health
            else:
                # A non-dict snapshot would break every property's .get().
                logger.debug(
                    "whisper-http health probe returned non-object JSON: %s",
                    type(health).__name__,
                )