406 lines
13 KiB
Text
406 lines
13 KiB
Text
"""Core STT service using faster-whisper for GPU-accelerated transcription."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import tempfile
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Literal
|
|
|
|
if TYPE_CHECKING:
|
|
from chatterbox_tts_service.config import ChatterboxSettings
|
|
|
|
from gpu_devices import is_cuda_available
|
|
|
|
logger = logging.getLogger(__name__)

# Whisper model identifiers that STTService.load_model() will accept; any other
# name is rejected with a ValueError. Order is preserved for display purposes.
WHISPER_MODELS = [
    "tiny",
    "tiny.en",
    "base",
    "base.en",
    "small",
    "small.en",
    "medium",
    "medium.en",
    "large-v2",
    "large-v3",
]
|
|
|
|
|
|
@dataclass
class TranscriptionSegment:
    """Text and timing for one contiguous piece of transcribed audio.

    ``start``, ``end`` and ``text`` are required; the two optional quality
    scores default to ``None`` when they are not reported.
    """

    # Segment boundaries as timestamps within the source audio.
    start: float
    end: float
    # Recognized text for this span.
    text: str
    # Optional quality signals; None means "not available".
    confidence: float | None = None
    no_speech_prob: float | None = None
|
|
|
|
|
|
@dataclass
class TranscriptionResult:
    """Aggregate outcome of transcribing one audio input.

    Carries the joined transcript, the detected/forced language with its
    probability, the input duration, the per-segment breakdown, and an
    optional mean confidence (``None`` when no segment reported one).
    """

    # Full transcript assembled from the individual segments.
    text: str
    # Language code and the model's probability for it.
    language: str
    language_probability: float
    # Length of the input audio.
    duration_seconds: float
    # Ordered per-segment details.
    segments: list[TranscriptionSegment]
    # Mean of the per-segment confidences, if any were available.
    average_confidence: float | None = None
|
|
|
|
|
|
class STTService:
    """Speech-to-Text service using faster-whisper.

    Keeps a reference to at most one Whisper model, loading it lazily on first
    use (or explicitly via ``load_model``). Blocking work -- model loading,
    decoding, and temp-file writes -- is pushed to worker threads with
    ``asyncio.to_thread`` so the event loop stays responsive. The device is
    chosen once at construction: CUDA with float16 when
    ``is_cuda_available()`` is true, otherwise CPU with int8 quantization.
    """

    # Values accepted for the ``task`` argument of ``transcribe``.
    _TASKS = ("transcribe", "translate")

    def __init__(self, settings: ChatterboxSettings) -> None:
        """Initialize the service; no model is loaded yet.

        Args:
            settings: Service configuration. ``settings.model_cache_dir`` is
                passed to faster-whisper as the model download root.
        """
        self.settings = settings
        self._model: Any | None = None
        self._current_model_name: str | None = None
        # Serializes loads/unloads so concurrent callers never load twice.
        self._model_lock = asyncio.Lock()
        self._device: str = "cpu"
        self._compute_type: str = "int8"

        # Determine device and compute type once, up front.
        if is_cuda_available():
            self._device = "cuda"
            self._compute_type = "float16"  # Use fp16 on GPU for speed
            logger.info("STT service will use CUDA acceleration")
        else:
            self._device = "cpu"
            self._compute_type = "int8"  # Use int8 quantization on CPU
            logger.info("STT service will use CPU (int8 quantization)")

        logger.info(
            "STTService initialized: device=%s, compute_type=%s",
            self._device,
            self._compute_type,
        )

    @property
    def is_model_loaded(self) -> bool:
        """Whether a model is currently held by the service."""
        return self._model is not None

    @property
    def current_model(self) -> str | None:
        """Name of the currently loaded model, or ``None``."""
        return self._current_model_name

    @property
    def device(self) -> str:
        """Compute device in use (``"cuda"`` or ``"cpu"``)."""
        return self._device

    @property
    def available_models(self) -> list[str]:
        """A fresh copy of the accepted Whisper model names."""
        return WHISPER_MODELS.copy()

    async def load_model(self, model_name: str = "base") -> None:
        """Load a Whisper model, replacing any different one already held.

        Args:
            model_name: Name of the Whisper model to load.

        Raises:
            ValueError: If model_name is not a valid Whisper model.
        """
        if model_name not in WHISPER_MODELS:
            raise ValueError(
                f"Invalid model '{model_name}'. Available models: {', '.join(WHISPER_MODELS)}"
            )

        # Fast path: nothing to do when the requested model is already active.
        if self._model is not None and self._current_model_name == model_name:
            logger.debug("Model '%s' already loaded, skipping", model_name)
            return

        async with self._model_lock:
            # Another coroutine may have loaded it while we waited for the lock.
            if self._model is not None and self._current_model_name == model_name:
                return

            logger.info("Loading Whisper model: %s on %s", model_name, self._device)

            # Drop our reference to the previous model before loading the new
            # one. (In-flight transcriptions still hold their own reference.)
            if self._model is not None:
                logger.info("Unloading previous model: %s", self._current_model_name)
                self._model = None
                self._current_model_name = None

            # Loading is blocking (download + weight init), so run it off-loop.
            self._model = await asyncio.to_thread(self._load_model_sync, model_name)
            self._current_model_name = model_name

            logger.info("Whisper model '%s' loaded successfully", model_name)

    def _load_model_sync(self, model_name: str) -> Any:
        """Synchronously construct a WhisperModel (called in a worker thread).

        Args:
            model_name: Name of the Whisper model to load.

        Returns:
            Loaded WhisperModel instance.
        """
        from faster_whisper import WhisperModel

        return WhisperModel(
            model_name,
            device=self._device,
            compute_type=self._compute_type,
            download_root=str(self.settings.model_cache_dir),
        )

    async def get_model(self, model_name: str = "base") -> Any:
        """Return the model for ``model_name``, loading it if necessary.

        Args:
            model_name: Name of the Whisper model to use.

        Returns:
            The WhisperModel instance.

        Raises:
            ValueError: If model_name is not a valid Whisper model.
        """
        if self._model is None or self._current_model_name != model_name:
            await self.load_model(model_name)
        return self._model

    async def transcribe(
        self,
        audio_path: Path | str,
        *,
        model: str = "base",
        language: str | None = None,
        task: Literal["transcribe", "translate"] = "transcribe",
        temperature: float = 0.0,
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1.0,
        vad_filter: bool = True,
        word_timestamps: bool = False,
    ) -> TranscriptionResult:
        """Transcribe an audio file to text.

        Args:
            audio_path: Path to audio file (supports WAV, MP3, WebM, etc.).
            model: Whisper model name to use.
            language: Language code (e.g., 'en', 'de'). Auto-detect if None.
            task: Either 'transcribe' (keep original language) or 'translate' (to English).
            temperature: Sampling temperature (0.0 = deterministic).
            beam_size: Beam size for beam search.
            best_of: Number of candidates when sampling.
            patience: Patience value for beam search.
            vad_filter: Enable voice activity detection filter.
            word_timestamps: Generate word-level timestamps.

        Returns:
            TranscriptionResult with text, language, and segments.

        Raises:
            FileNotFoundError: If audio_path is not an existing regular file.
            ValueError: If ``task`` or ``model`` is invalid.
        """
        audio_path = Path(audio_path)
        # is_file() (not exists()) so a directory path fails fast and clearly.
        if not audio_path.is_file():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        if task not in self._TASKS:
            raise ValueError(
                f"Invalid task '{task}'. Expected one of: {', '.join(self._TASKS)}"
            )

        # Get or load model (raises ValueError for unknown model names).
        whisper_model = await self.get_model(model)

        logger.info(
            "Transcribing: file=%s, model=%s, language=%s, task=%s",
            audio_path.name,
            model,
            language or "auto",
            task,
        )

        # Decoding is CPU/GPU-bound and blocking; run it in a worker thread.
        result = await asyncio.to_thread(
            self._transcribe_sync,
            whisper_model,
            audio_path,
            language,
            task,
            temperature,
            beam_size,
            best_of,
            patience,
            vad_filter,
            word_timestamps,
        )

        logger.info(
            "Transcription complete: text_len=%d, language=%s, duration=%.2fs",
            len(result.text),
            result.language,
            result.duration_seconds,
        )

        return result

    def _transcribe_sync(
        self,
        model: Any,
        audio_path: Path,
        language: str | None,
        task: str,
        temperature: float,
        beam_size: int,
        best_of: int,
        patience: float,
        vad_filter: bool,
        word_timestamps: bool,
    ) -> TranscriptionResult:
        """Synchronously transcribe audio (called in a worker thread).

        Per-segment ``confidence`` is ``exp(segment.avg_logprob)`` when the
        segment reports an ``avg_logprob``; ``average_confidence`` is the
        unweighted arithmetic mean of those values (``None`` if there are
        none). The full text joins the stripped, non-empty segment texts with
        single spaces.

        Args:
            model: WhisperModel instance.
            audio_path: Path to audio file.
            language: Language code or None for auto-detection.
            task: 'transcribe' or 'translate'.
            temperature: Sampling temperature.
            beam_size: Beam size for beam search.
            best_of: Number of candidates when sampling.
            patience: Patience value for beam search.
            vad_filter: Enable VAD filter.
            word_timestamps: Generate word-level timestamps.

        Returns:
            TranscriptionResult.
        """
        import math  # local import: keeps the module's top-level imports unchanged

        segments_iter, info = model.transcribe(
            str(audio_path),
            language=language,
            task=task,
            temperature=temperature,
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            vad_filter=vad_filter,
            word_timestamps=word_timestamps,
        )

        segments: list[TranscriptionSegment] = []
        text_parts: list[str] = []
        confidences: list[float] = []

        # faster-whisper returns segments as a generator, so the actual
        # decoding work happens while iterating here (still in the worker thread).
        for segment in segments_iter:
            text = segment.text.strip()

            # getattr(..., None) also guards against the attribute being None,
            # which would make math.exp raise TypeError.
            avg_logprob = getattr(segment, "avg_logprob", None)
            confidence = None
            if avg_logprob is not None:
                confidence = math.exp(avg_logprob)
                confidences.append(confidence)

            segments.append(
                TranscriptionSegment(
                    start=segment.start,
                    end=segment.end,
                    text=text,
                    confidence=confidence,
                    no_speech_prob=getattr(segment, "no_speech_prob", None),
                )
            )
            # Skip empty texts so the joined transcript has no double spaces.
            if text:
                text_parts.append(text)

        avg_confidence = sum(confidences) / len(confidences) if confidences else None

        return TranscriptionResult(
            text=" ".join(text_parts),
            language=info.language,
            language_probability=info.language_probability,
            duration_seconds=info.duration,
            segments=segments,
            average_confidence=avg_confidence,
        )

    async def transcribe_bytes(
        self,
        audio_bytes: bytes,
        *,
        model: str = "base",
        language: str | None = None,
        task: Literal["transcribe", "translate"] = "transcribe",
        **kwargs: Any,
    ) -> TranscriptionResult:
        """Transcribe audio supplied as raw bytes.

        The bytes are written to a temporary file, which is transcribed and
        then deleted -- also when writing or transcription fails.

        Args:
            audio_bytes: Raw audio file bytes.
            model: Whisper model name to use.
            language: Language code or None for auto-detection.
            task: Either 'transcribe' or 'translate'.
            **kwargs: Additional arguments passed to transcribe().

        Returns:
            TranscriptionResult.
        """
        # Create the (empty) temp file synchronously -- this is cheap -- so its
        # path is known before any await and the finally below can always
        # remove it, including when the write itself fails.
        with tempfile.NamedTemporaryFile(suffix=".audio", delete=False) as tmp:
            tmp_path = Path(tmp.name)

        try:
            # Writing potentially large payloads is blocking disk I/O.
            await asyncio.to_thread(tmp_path.write_bytes, audio_bytes)
            return await self.transcribe(
                tmp_path,
                model=model,
                language=language,
                task=task,
                **kwargs,
            )
        finally:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError as e:
                logger.warning("Failed to delete temporary file %s: %s", tmp_path, e)

    async def unload_model(self) -> None:
        """Drop the loaded model and, on CUDA, clear torch's cache (best effort).

        NOTE(review): faster-whisper runs on CTranslate2, whose GPU memory is
        not managed by torch's caching allocator, so ``empty_cache`` may have
        little effect here -- confirm before relying on it.
        """
        async with self._model_lock:
            if self._model is None:
                return

            logger.info("Unloading Whisper model: %s", self._current_model_name)

            self._model = None
            self._current_model_name = None

            if self._device == "cuda":
                # Best effort: a missing torch or CUDA error must not fail unload.
                try:
                    import torch

                    torch.cuda.empty_cache()
                    logger.info("CUDA cache cleared")
                except Exception as e:
                    logger.warning("Failed to clear CUDA cache: %s", e)

            logger.info("Model unloaded successfully")

    def get_model_info(self) -> dict[str, Any]:
        """Summarize the service state.

        Returns:
            Dictionary with ``model_loaded``, ``current_model``, ``device``,
            ``compute_type`` and ``available_models`` keys.
        """
        return {
            "model_loaded": self.is_model_loaded,
            "current_model": self._current_model_name,
            "device": self._device,
            "compute_type": self._compute_type,
            "available_models": self.available_models,
        }
|