313 lines
12 KiB
Text
313 lines
12 KiB
Text
"""STT service delegating to the model-boss whisper-http backend.
|
|
|
|
Mirrors the architecture of `tts_service.py`: this process owns no Whisper
|
|
model and no VRAM. All transcription requests are proxied to the
|
|
`whisper-http` service (which acquires a model-boss VRAM lease around its
|
|
own model lifecycle, coordinating with TTS to prevent OOM).
|
|
|
|
Public surface is preserved 1:1 with the old in-process implementation so
|
|
the existing routes (routes/stt.py) and websocket streamers don't need
|
|
changes — only the internals are different.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Literal
|
|
|
|
if TYPE_CHECKING:
|
|
from chatterbox_tts_service.config import ChatterboxSettings
|
|
|
|
logger = logging.getLogger(__name__)


# Whisper model names that whisper-http accepts (the set its WhisperLoader
# knows about). Every public STTService entry point validates against this.
WHISPER_MODELS = [
    "tiny",
    "tiny.en",
    "base",
    "base.en",
    "small",
    "small.en",
    "medium",
    "medium.en",
    "large-v2",
    "large-v3",
    "turbo",
]
|
|
|
|
|
|
@dataclass
class TranscriptionSegment:
    """One timed span of transcribed speech returned by the backend.

    ``start``/``end`` are the segment boundaries as reported by whisper-http
    (presumably seconds — confirm against the backend).
    """

    start: float
    end: float
    text: str
    # Optional per-segment scores; stay None when the backend omits them.
    confidence: float | None = field(default=None)
    no_speech_prob: float | None = field(default=None)
|
|
|
|
|
|
@dataclass
class TranscriptionResult:
    """Outcome of a single transcription request.

    The four leading fields are required; ``segments`` starts as a fresh
    empty list per instance, and the remaining fields have neutral defaults.
    """

    text: str
    language: str
    language_probability: float
    duration_seconds: float
    # default_factory so instances never share one mutable list.
    segments: list[TranscriptionSegment] = field(default_factory=list)
    # Mean of the per-segment confidences, or None when none were reported.
    average_confidence: float | None = field(default=None)
    # Name of the Whisper model that produced this result.
    model_used: str = field(default="")
|
|
|
|
|
|
class STTService:
    """Speech-to-Text service backed by the whisper-http endpoint.

    No VRAM is held by this process — model lifecycle is owned by model-boss
    via whisper-http. State tracking (`current_model`, `is_model_loaded`)
    reflects the last successful HTTP transcription or health probe, not
    local memory.
    """

    def __init__(self, settings: ChatterboxSettings) -> None:
        self.settings = settings
        # whisper_http_url is read off the settings when present; otherwise
        # fall back to localhost:10011 (matches the systemd unit / router.py
        # default). A trailing slash is stripped so the endpoint URLs built
        # below never contain "//transcribe" or "//health".
        base_url = getattr(settings, "whisper_http_url", None) or "http://localhost:10011"
        self._whisper_http_url: str = str(base_url).rstrip("/")
        self._http_client: Any | None = None
        self._last_model: str | None = None
        # Cached health snapshot so `device` / `is_model_loaded` can return
        # meaningful values without doing a network round-trip on every read.
        self._cached_health: dict[str, Any] = {}
        self._health_lock = asyncio.Lock()

    # ─── Public properties (preserved API) ────────────────────────────────────

    @property
    def is_model_loaded(self) -> bool:
        """Whether whisper-http has reported a loaded model.

        Returns False until the first successful health check / transcription,
        and again after `unload_model()`.
        """
        return bool(self._cached_health.get("model_loaded"))

    @property
    def current_model(self) -> str | None:
        """Model id from the health snapshot, else the last model requested/used."""
        return self._cached_health.get("model_id") or self._last_model

    @property
    def device(self) -> str:
        """Device whisper-http reports running on. 'remote' if no health snapshot yet."""
        return self._cached_health.get("device") or "remote"

    @property
    def available_models(self) -> list[str]:
        """A copy of WHISPER_MODELS (callers may mutate it freely)."""
        return WHISPER_MODELS.copy()

    # ─── Lifecycle (no-ops — backend owns model) ──────────────────────────────

    async def load_model(self, model_name: str = "base") -> None:
        """Validate and record *model_name*, then refresh the cached health snapshot.

        This does not make whisper-http load anything: whisper-http loads
        lazily on the first /transcribe. The method exists for API
        compatibility — callers that want to pre-warm can pass a short
        synthetic clip through transcribe_bytes() instead.

        Raises:
            ValueError: if *model_name* is not in WHISPER_MODELS.
        """
        self._validate_model_name(model_name)
        self._last_model = model_name
        await self._refresh_health()

    async def get_model(self, model_name: str = "base") -> None:
        """No-op — model lives in whisper-http. Validates and records the name.

        Raises:
            ValueError: if *model_name* is not in WHISPER_MODELS.
        """
        self._validate_model_name(model_name)
        self._last_model = model_name
        return None

    async def unload_model(self) -> None:
        """Forget the cached loaded-model state; whisper-http is not contacted.

        whisper-http manages model unload via its own lease lifecycle. After
        this call `is_model_loaded` is False; `current_model` still falls back
        to the last model requested/used.
        """
        self._cached_health.pop("model_loaded", None)
        self._cached_health.pop("model_id", None)

    def get_model_info(self) -> dict[str, Any]:
        """Return a diagnostic snapshot of backend, URL and cached model state."""
        return {
            "backend": "whisper-http",
            "url": self._whisper_http_url,
            "current_model": self.current_model,
            "device": self.device,
            "is_model_loaded": self.is_model_loaded,
            "available_models": self.available_models,
            # Copy so callers cannot mutate the service's cached state.
            "health": dict(self._cached_health),
        }

    # ─── Transcription (proxies to whisper-http) ──────────────────────────────

    async def transcribe(
        self,
        audio_path: Path | str,
        *,
        model: str = "base",
        language: str | None = None,
        task: Literal["transcribe", "translate"] = "transcribe",
        temperature: float = 0.0,  # accepted for API parity, not forwarded
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1.0,  # accepted for API parity, not forwarded
        length_penalty: float = 1.0,  # accepted for API parity, not forwarded
        initial_prompt: str | None = None,
        word_timestamps: bool = False,
        vad_filter: bool = False,
        vad_parameters: dict[str, Any] | None = None,  # accepted, not forwarded
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe an audio file via whisper-http.

        The file is read off the event loop (``asyncio.to_thread``) and handed
        to `transcribe_bytes`, which base64-encodes and posts it.

        Raises:
            FileNotFoundError: if *audio_path* does not exist.
            ValueError: for an unknown model or an empty file.
            RuntimeError: if the whisper-http request fails.
        """
        path = Path(audio_path)
        if not path.exists():
            raise FileNotFoundError(f"Audio file not found: {path}")
        audio_bytes = await asyncio.to_thread(path.read_bytes)
        return await self.transcribe_bytes(
            audio_bytes,
            model=model,
            language=language,
            task=task,
            beam_size=beam_size,
            best_of=best_of,
            initial_prompt=initial_prompt,
            word_timestamps=word_timestamps,
            vad_filter=vad_filter,
            **kwargs,
        )

    async def transcribe_bytes(
        self,
        audio_bytes: bytes,
        *,
        model: str = "base",
        language: str | None = None,
        task: Literal["transcribe", "translate"] = "transcribe",
        beam_size: int = 5,
        best_of: int = 5,
        initial_prompt: str | None = None,
        word_timestamps: bool = False,
        vad_filter: bool = False,
        _retry: bool = False,
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe raw encoded audio bytes via whisper-http's /transcribe.

        *language* and *initial_prompt* are only sent when non-empty. Extra
        ``**kwargs`` are accepted for API parity and ignored. A connection-level
        failure is retried once unless ``_retry`` is already True.

        On success the cached health snapshot is updated to mark the returned
        model as loaded.

        Raises:
            ValueError: for an unknown model or empty *audio_bytes*.
            RuntimeError: on any transport failure, non-200 status, or
                malformed response body.
        """
        self._validate_model_name(model)
        if not audio_bytes:
            raise ValueError("transcribe_bytes called with empty audio")

        payload: dict[str, Any] = {
            "audio": base64.b64encode(audio_bytes).decode("ascii"),
            "model": model,
            "task": task,
            "beam_size": beam_size,
            "best_of": best_of,
            "word_timestamps": word_timestamps,
            "vad_filter": vad_filter,
        }
        if language:
            payload["language"] = language
        if initial_prompt:
            payload["initial_prompt"] = initial_prompt

        data = await self._post_transcribe(payload, allow_retry=not _retry)
        result = self._parse_transcription(data, model=model, language=language)

        self._last_model = result.model_used or model
        # Opportunistically refresh health from the successful response so
        # is_model_loaded / current_model reflect reality without an extra
        # round-trip.
        self._cached_health["model_loaded"] = True
        self._cached_health["model_id"] = self._last_model
        return result

    async def cleanup(self) -> None:
        """Close the shared HTTP client, if one was created."""
        if self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None

    # ─── Private ──────────────────────────────────────────────────────────────

    @staticmethod
    def _validate_model_name(model_name: str) -> None:
        """Raise ValueError unless *model_name* is one of WHISPER_MODELS."""
        if model_name not in WHISPER_MODELS:
            raise ValueError(
                f"Invalid model '{model_name}'. Available: {', '.join(WHISPER_MODELS)}"
            )

    @staticmethod
    def _opt_float(value: Any, default: float | None = None) -> float | None:
        """Coerce *value* to float, mapping a missing/JSON-null value to *default*."""
        return default if value is None else float(value)

    async def _post_transcribe(
        self, payload: dict[str, Any], *, allow_retry: bool
    ) -> dict[str, Any]:
        """POST *payload* to /transcribe and return the decoded JSON object.

        Connection-level failures (refused connection, dropped keep-alive
        socket) are retried once when *allow_retry* is true. Every transport,
        status, or decoding failure is surfaced as RuntimeError so callers
        only have to handle one exception type.
        """
        import httpx

        url = f"{self._whisper_http_url}/transcribe"
        attempts = 2 if allow_retry else 1
        response = None
        for attempt in range(attempts):
            try:
                response = await self._get_http_client().post(url, json=payload)
                break
            except (httpx.ConnectError, httpx.RemoteProtocolError) as exc:
                if attempt + 1 < attempts:
                    logger.warning("whisper-http connection error, retrying once: %s", exc)
                    continue
                raise RuntimeError(f"whisper-http request failed: {exc}") from exc
            except httpx.HTTPError as exc:
                # Timeouts, read errors, etc. — not retried, but wrapped so the
                # failure type matches the other error paths.
                raise RuntimeError(f"whisper-http request failed: {exc}") from exc

        if response.status_code != 200:
            raise RuntimeError(
                f"whisper-http returned {response.status_code}: {response.text[:500]}"
            )
        try:
            data = response.json()
        except ValueError as exc:
            raise RuntimeError(
                f"whisper-http returned invalid JSON: {response.text[:500]}"
            ) from exc
        if not isinstance(data, dict):
            raise RuntimeError(
                f"whisper-http returned unexpected payload type: {type(data).__name__}"
            )
        return data

    @classmethod
    def _parse_transcription(
        cls, data: dict[str, Any], *, model: str, language: str | None
    ) -> TranscriptionResult:
        """Build a TranscriptionResult from a /transcribe JSON body.

        Missing or null fields fall back to sensible defaults instead of being
        stringified to "None" or crashing in float().
        """
        segments = [
            TranscriptionSegment(
                start=float(seg["start"]),
                end=float(seg["end"]),
                text=str(seg["text"]),
                confidence=cls._opt_float(seg.get("confidence")),
                no_speech_prob=cls._opt_float(seg.get("no_speech_prob")),
            )
            for seg in data.get("segments") or []
        ]

        confidences = [s.confidence for s in segments if s.confidence is not None]
        avg_conf = sum(confidences) / len(confidences) if confidences else None

        return TranscriptionResult(
            text=str(data.get("text") or ""),
            language=str(data.get("language") or language or "unknown"),
            language_probability=cls._opt_float(data.get("language_probability"), 1.0),
            duration_seconds=cls._opt_float(data.get("duration_seconds"), 0.0),
            segments=segments,
            average_confidence=avg_conf,
            model_used=str(data.get("model_used") or model),
        )

    def _get_http_client(self) -> Any:
        """Return the shared httpx.AsyncClient, creating it on first use."""
        import httpx

        if self._http_client is None:
            self._http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(180.0, connect=10.0),
            )
        return self._http_client

    async def _refresh_health(self) -> None:
        """Fetch /health from whisper-http and cache it.

        Best-effort: network errors, non-200 statuses and non-object JSON are
        logged at DEBUG and leave the previous snapshot untouched.
        """
        async with self._health_lock:
            client = self._get_http_client()
            try:
                response = await client.get(f"{self._whisper_http_url}/health", timeout=5.0)
                if response.status_code != 200:
                    logger.debug(
                        "whisper-http health probe returned status %s", response.status_code
                    )
                    return
                snapshot = response.json()
            except Exception as exc:  # noqa: BLE001 — deliberate best-effort probe
                logger.debug("whisper-http health probe failed: %s", exc)
                return
            if isinstance(snapshot, dict):
                self._cached_health = snapshot
            else:
                logger.debug(
                    "whisper-http health probe returned non-object JSON: %r", snapshot
                )
|