session-tools/contrib/apricot-stt-refactor/stt_service.py.preWHISPERHTTP
Natalie 7138338d31 feat(@scripts): add whisper-http backend config and stt service refactor
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-05-17 22:11:19 -07:00

406 lines
13 KiB
Text

"""Core STT service using faster-whisper for GPU-accelerated transcription."""
from __future__ import annotations
import asyncio
import logging
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from chatterbox_tts_service.config import ChatterboxSettings
from gpu_devices import is_cuda_available
logger = logging.getLogger(__name__)

# Whisper model identifiers this module accepts; STTService.load_model
# rejects any name that is not listed here.
WHISPER_MODELS = [
    "tiny",
    "tiny.en",
    "base",
    "base.en",
    "small",
    "small.en",
    "medium",
    "medium.en",
    "large-v2",
    "large-v3",
]
@dataclass
class TranscriptionSegment:
    """One contiguous piece of a transcription.

    Attributes:
        start: Segment start position, as reported by the model.
        end: Segment end position, as reported by the model.
        text: Transcribed text for this segment.
        confidence: Optional confidence score; ``None`` when not available.
        no_speech_prob: Optional no-speech probability; ``None`` when the
            model did not report one.
    """

    # Required fields come first; their order is part of the positional
    # constructor signature and must not change.
    start: float
    end: float
    text: str

    # Optional scores default to None.
    confidence: float | None = None
    no_speech_prob: float | None = None
@dataclass
class TranscriptionResult:
    """Outcome of transcribing one audio input.

    Attributes:
        text: Full transcript text.
        language: Language code reported for the audio.
        language_probability: Probability attached to ``language``.
        duration_seconds: Audio duration in seconds.
        segments: Individual transcription segments, in the order produced.
        average_confidence: Mean of the per-segment confidences, or ``None``
            when no segment carried a confidence value.
    """

    # Field order defines the positional constructor signature.
    text: str
    language: str
    language_probability: float
    duration_seconds: float
    segments: list[TranscriptionSegment]

    # Optional aggregate; defaults to None.
    average_confidence: float | None = None
class STTService:
    """Speech-to-Text service using faster-whisper.

    Provides lazy model loading, GPU acceleration, and high-quality transcription.

    At most one Whisper model is held at a time: requesting a different model
    name replaces the one currently loaded. Model loading and transcription
    are blocking calls and are therefore dispatched with
    ``asyncio.to_thread``.
    """

    def __init__(self, settings: ChatterboxSettings) -> None:
        """Initialize the STT service.

        No model is loaded here; loading happens on the first call to
        ``load_model``/``get_model``/``transcribe``.

        Args:
            settings: Service configuration. ``settings.model_cache_dir`` is
                used as the download root when a model is loaded.
        """
        self.settings = settings
        self._model: Any | None = None
        self._current_model_name: str | None = None
        # Serializes model load/unload so two coroutines never swap models
        # at the same time.
        self._model_lock = asyncio.Lock()

        # CPU with int8 quantization is the fallback; upgraded to CUDA/fp16
        # below when a GPU is reported as available.
        self._device: str = "cpu"
        self._compute_type: str = "int8"
        if is_cuda_available():
            self._device = "cuda"
            self._compute_type = "float16"  # Use fp16 on GPU for speed
            logger.info("STT service will use CUDA acceleration")
        else:
            logger.info("STT service will use CPU (int8 quantization)")

        logger.info(
            f"STTService initialized: device={self._device}, compute_type={self._compute_type}"
        )

    @property
    def is_model_loaded(self) -> bool:
        """Check if a model is loaded."""
        return self._model is not None

    @property
    def current_model(self) -> str | None:
        """Get the currently loaded model name, or None if nothing is loaded."""
        return self._current_model_name

    @property
    def device(self) -> str:
        """Get the current compute device ("cuda" or "cpu")."""
        return self._device

    @property
    def available_models(self) -> list[str]:
        """Get list of available Whisper models.

        A copy is returned so callers cannot mutate the module-level list.
        """
        return WHISPER_MODELS.copy()

    async def load_model(self, model_name: str = "base") -> None:
        """Load a Whisper model, replacing any different model already loaded.

        If ``model_name`` is already the loaded model this is a no-op. When
        switching models, the reference to the previous model is dropped
        before the new one is loaded, so a failed load leaves the service
        with no model loaded.

        Args:
            model_name: Name of the Whisper model to load.

        Raises:
            ValueError: If model_name is not a valid Whisper model.
        """
        if model_name not in WHISPER_MODELS:
            raise ValueError(
                f"Invalid model '{model_name}'. Available models: {', '.join(WHISPER_MODELS)}"
            )

        # Fast path without taking the lock: same model already loaded.
        if self._model is not None and self._current_model_name == model_name:
            logger.debug(f"Model '{model_name}' already loaded, skipping")
            return

        async with self._model_lock:
            # Re-check: another coroutine may have loaded this model while
            # we were waiting for the lock.
            if self._model is not None and self._current_model_name == model_name:
                return

            logger.info(f"Loading Whisper model: {model_name} on {self._device}")

            # A different model is loaded: drop our reference to it first so
            # both models are not held at once.
            if self._model is not None:
                logger.info(f"Unloading previous model: {self._current_model_name}")
                self._model = None
                self._current_model_name = None

            # Load model in thread pool to avoid blocking
            self._model = await asyncio.to_thread(self._load_model_sync, model_name)
            self._current_model_name = model_name
            logger.info(f"Whisper model '{model_name}' loaded successfully")

    def _load_model_sync(self, model_name: str) -> Any:
        """Synchronously load the model (called in thread pool).

        Args:
            model_name: Name of the Whisper model to load.

        Returns:
            Loaded WhisperModel instance.
        """
        # Imported lazily so the dependency is only needed once a model is
        # actually requested.
        from faster_whisper import WhisperModel

        return WhisperModel(
            model_name,
            device=self._device,
            compute_type=self._compute_type,
            download_root=str(self.settings.model_cache_dir),
        )

    async def get_model(self, model_name: str = "base") -> Any:
        """Get the loaded model, loading if necessary.

        Args:
            model_name: Name of the Whisper model to use.

        Returns:
            The WhisperModel instance.

        Raises:
            ValueError: If a load is needed and model_name is not a valid
                Whisper model.
        """
        if self._model is None or self._current_model_name != model_name:
            await self.load_model(model_name)
        return self._model

    async def transcribe(
        self,
        audio_path: Path | str,
        *,
        model: str = "base",
        language: str | None = None,
        task: Literal["transcribe", "translate"] = "transcribe",
        temperature: float = 0.0,
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1.0,
        vad_filter: bool = True,
        word_timestamps: bool = False,
    ) -> TranscriptionResult:
        """Transcribe audio file to text.

        Args:
            audio_path: Path to audio file (supports WAV, MP3, WebM, etc.).
            model: Whisper model name to use.
            language: Language code (e.g., 'en', 'de'). Auto-detect if None.
            task: Either 'transcribe' (keep original language) or 'translate' (to English).
            temperature: Sampling temperature (0.0 = deterministic).
            beam_size: Beam size for beam search.
            best_of: Number of candidates when sampling.
            patience: Patience value for beam search.
            vad_filter: Enable voice activity detection filter.
            word_timestamps: Generate word-level timestamps.

        Returns:
            TranscriptionResult with text, language, and segments.

        Raises:
            FileNotFoundError: If audio file doesn't exist. Checked before
                any model is loaded.
            ValueError: If ``model`` is not a valid Whisper model name
                (raised by ``load_model``).
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        # Get or load model
        whisper_model = await self.get_model(model)

        logger.info(
            f"Transcribing: file={audio_path.name}, model={model}, "
            f"language={language or 'auto'}, task={task}"
        )

        # Run transcription in thread pool
        result = await asyncio.to_thread(
            self._transcribe_sync,
            whisper_model,
            audio_path,
            language,
            task,
            temperature,
            beam_size,
            best_of,
            patience,
            vad_filter,
            word_timestamps,
        )

        logger.info(
            f"Transcription complete: text_len={len(result.text)}, "
            f"language={result.language}, duration={result.duration_seconds:.2f}s"
        )
        return result

    def _transcribe_sync(
        self,
        model: Any,
        audio_path: Path,
        language: str | None,
        task: str,
        temperature: float,
        beam_size: int,
        best_of: int,
        patience: float,
        vad_filter: bool,
        word_timestamps: bool,
    ) -> TranscriptionResult:
        """Synchronously transcribe audio (called in thread pool).

        Args:
            model: WhisperModel instance.
            audio_path: Path to audio file.
            language: Language code or None for auto-detection.
            task: 'transcribe' or 'translate'.
            temperature: Sampling temperature.
            beam_size: Beam size for beam search.
            best_of: Number of candidates when sampling.
            patience: Patience value for beam search.
            vad_filter: Enable VAD filter.
            word_timestamps: Generate word-level timestamps.

        Returns:
            TranscriptionResult whose ``text`` is the stripped segment texts
            joined with single spaces, and whose ``average_confidence`` is
            the mean of the per-segment confidences (None if there are none).
        """
        # Imported once per call rather than once per segment.
        import math

        segments_iter, info = model.transcribe(
            str(audio_path),
            language=language,
            task=task,
            temperature=temperature,
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            vad_filter=vad_filter,
            word_timestamps=word_timestamps,
        )

        segments: list[TranscriptionSegment] = []
        texts: list[str] = []
        confidences: list[float] = []

        # NOTE(review): faster-whisper presumably yields segments lazily, so
        # the decoding work would happen while this loop runs - confirm
        # against the library documentation.
        for segment in segments_iter:
            text = segment.text.strip()

            # Confidence is exp(avg_logprob). A missing or None
            # avg_logprob yields no confidence rather than an error.
            avg_logprob = getattr(segment, "avg_logprob", None)
            confidence = None
            if avg_logprob is not None:
                confidence = math.exp(avg_logprob)
                confidences.append(confidence)

            segments.append(
                TranscriptionSegment(
                    start=segment.start,
                    end=segment.end,
                    text=text,
                    confidence=confidence,
                    no_speech_prob=getattr(segment, "no_speech_prob", None),
                )
            )
            texts.append(text)

        # Mean over the segments that reported a confidence.
        avg_confidence = sum(confidences) / len(confidences) if confidences else None

        return TranscriptionResult(
            text=" ".join(texts),
            language=info.language,
            language_probability=info.language_probability,
            duration_seconds=info.duration,
            segments=segments,
            average_confidence=avg_confidence,
        )

    async def transcribe_bytes(
        self,
        audio_bytes: bytes,
        *,
        model: str = "base",
        language: str | None = None,
        task: Literal["transcribe", "translate"] = "transcribe",
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio from bytes.

        Saves bytes to a temporary file (suffix ``.audio``), transcribes that
        file, and removes it afterwards. The file is removed even when
        writing the bytes or the transcription itself fails.

        Args:
            audio_bytes: Raw audio file bytes.
            model: Whisper model name to use.
            language: Language code or None for auto-detection.
            task: Either 'transcribe' or 'translate'.
            **kwargs: Additional arguments passed to transcribe().

        Returns:
            TranscriptionResult.
        """
        # delete=False because the file must outlive the handle: it is
        # reopened by path during transcription and removed in ``finally``.
        tmp = tempfile.NamedTemporaryFile(suffix=".audio", delete=False)
        tmp_path = Path(tmp.name)
        try:
            # The write sits inside the try so that a failed write does not
            # leave the temporary file behind.
            with tmp:
                tmp.write(audio_bytes)

            return await self.transcribe(
                tmp_path,
                model=model,
                language=language,
                task=task,
                **kwargs,
            )
        finally:
            # Best-effort cleanup: a failure to delete is logged, not raised,
            # so it never masks the transcription result or error.
            try:
                tmp_path.unlink()
            except OSError as e:
                logger.warning(f"Failed to delete temporary file {tmp_path}: {e}")

    async def unload_model(self) -> None:
        """Unload the current model and free memory.

        Does nothing when no model is loaded. On CUDA, additionally tries to
        clear the torch CUDA cache; a failure there is logged and ignored.
        """
        async with self._model_lock:
            if self._model is None:
                return

            logger.info(f"Unloading Whisper model: {self._current_model_name}")

            # Drop our reference to the model.
            self._model = None
            self._current_model_name = None

            # Clear CUDA cache if on GPU
            if self._device == "cuda":
                try:
                    import torch

                    torch.cuda.empty_cache()
                    logger.info("CUDA cache cleared")
                except Exception as e:
                    # Deliberately broad: torch may be missing or fail in
                    # several ways, and unloading must still succeed.
                    logger.warning(f"Failed to clear CUDA cache: {e}")

            logger.info("Model unloaded successfully")

    def get_model_info(self) -> dict[str, Any]:
        """Get information about the STT service state.

        Returns:
            Dictionary with the keys ``model_loaded``, ``current_model``,
            ``device``, ``compute_type`` and ``available_models``.
        """
        return {
            "model_loaded": self.is_model_loaded,
            "current_model": self._current_model_name,
            "device": self._device,
            "compute_type": self._compute_type,
            "available_models": self.available_models,
        }