chore(imajin-classifier): 🔧 Update server-side dependencies for classifier compatibility
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
f7a17fade3
commit
d56993c193
13 changed files with 1332 additions and 0 deletions
49
services/imajin-classifier/service/pyproject.toml
Normal file
49
services/imajin-classifier/service/pyproject.toml
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
[project]
|
||||
name = "imajin-classifier"
|
||||
version = "0.1.0"
|
||||
description = "General-purpose multi-dimensional image classifier for Imajin pipeline"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.32.0",
|
||||
"torch>=2.2.0",
|
||||
"transformers>=4.40.0",
|
||||
"pillow>=10.0.0",
|
||||
"pydantic>=2.10.0",
|
||||
"pydantic-settings>=2.6.0",
|
||||
"httpx>=0.27.0",
|
||||
"pyyaml>=6.0.0",
|
||||
"scipy>=1.13.0",
|
||||
"slowapi>=0.1.9",
|
||||
# Observability
|
||||
"structlog>=24.0.0",
|
||||
"prometheus-fastapi-instrumentator>=7.0.0",
|
||||
# GPU coordination via model-boss (zero-config)
|
||||
# Uses REDIS_URL env var (default: redis://localhost:6379)
|
||||
"model-boss>=4.0.0",
|
||||
# Lilith service infrastructure
|
||||
"lilith-service-fastapi-bootstrap>=3.0.0",
|
||||
"lilith-ml-exceptions>=0.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=9.0.0",
|
||||
"pytest-asyncio>=1.3.0",
|
||||
"pytest-cov>=7.0.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_functions = ["test_*"]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["src"]
|
||||
1
services/imajin-classifier/service/src/__init__.py
Normal file
1
services/imajin-classifier/service/src/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""imajin-classifier service."""
|
||||
1
services/imajin-classifier/service/src/api/__init__.py
Normal file
1
services/imajin-classifier/service/src/api/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""API module for imajin-classifier service."""
|
||||
458
services/imajin-classifier/service/src/api/main.py
Normal file
458
services/imajin-classifier/service/src/api/main.py
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
"""FastAPI application for the imajin-classifier service.
|
||||
|
||||
General-purpose multi-dimensional image classifier using SigLIP2 zero-shot
|
||||
classification. Accepts arbitrary scoring rubrics (dimension name + positive/
|
||||
negative text prompts) and returns per-dimension scores.
|
||||
|
||||
Uses lilith-fastapi-service-base for standardized service patterns and
|
||||
model-boss for GPU lease coordination.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from lilith_service_fastapi_bootstrap import (
|
||||
BaseServiceSettings,
|
||||
LifespanManager,
|
||||
apply_cors,
|
||||
get_logger,
|
||||
)
|
||||
from PIL import Image
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from ..config.settings import settings
|
||||
from ..models import (
|
||||
CalibrateRequest,
|
||||
CalibrateResult,
|
||||
ClassifyRequest,
|
||||
ClassifyResult,
|
||||
DimensionCalibration,
|
||||
PresetInfo,
|
||||
PresetsResult,
|
||||
)
|
||||
from ..presets import registry as preset_registry
|
||||
from ..scoring import DimensionScorer
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Structured logging
|
||||
# ---------------------------------------------------------------------------
|
||||
structlog.configure(
|
||||
processors=[
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.processors.add_log_level,
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
structlog.processors.JSONRenderer(),
|
||||
],
|
||||
wrapper_class=structlog.make_filtering_bound_logger(logging.INFO),
|
||||
context_class=dict,
|
||||
logger_factory=structlog.PrintLoggerFactory(),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request-ID middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""Injects or propagates X-Request-ID on every request/response."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next): # type: ignore[override]
|
||||
request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))
|
||||
request.state.request_id = request_id
|
||||
structlog.contextvars.clear_contextvars()
|
||||
structlog.contextvars.bind_contextvars(request_id=request_id)
|
||||
response = await call_next(request)
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
return response
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Lifespan Manager
|
||||
# =============================================================================
|
||||
|
||||
lifespan = LifespanManager()
|
||||
|
||||
|
||||
@lifespan.on_startup
|
||||
async def init_service():
|
||||
"""Initialize dimension scorer with model-boss GPU lease."""
|
||||
logger.info("Initializing imajin-classifier dimension scorer...")
|
||||
|
||||
from model_boss.client import InferenceClient
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
inference_client = InferenceClient(
|
||||
client_id="imajin-classifier",
|
||||
auto_start_services=False,
|
||||
)
|
||||
lease = await inference_client.acquire_lease(
|
||||
model_id=f"service:{settings.model_name}",
|
||||
vram_mb=2048, # SigLIP2 so400m is ~2GB
|
||||
priority="normal",
|
||||
)
|
||||
lease_id = lease["lease_id"]
|
||||
gpu_index = lease["gpu_index"]
|
||||
device = f"cuda:{gpu_index}"
|
||||
|
||||
lifespan.set_state("inference_client", inference_client)
|
||||
lifespan.set_state("lease_id", lease_id)
|
||||
|
||||
logger.info(f"GPU lease acquired: {device} (lease_id={lease_id})")
|
||||
|
||||
async def _heartbeat_loop() -> None:
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(10.0)
|
||||
await inference_client.heartbeat(lease_id)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
heartbeat_task = asyncio.create_task(_heartbeat_loop())
|
||||
lifespan.set_state("heartbeat_task", heartbeat_task)
|
||||
|
||||
# Load model and processor
|
||||
model = AutoModel.from_pretrained(settings.model_name).to(device)
|
||||
processor = AutoProcessor.from_pretrained(settings.model_name)
|
||||
lifespan.set_state("model", model)
|
||||
lifespan.set_state("processor", processor)
|
||||
|
||||
logger.info(f"Model loaded on {device}")
|
||||
|
||||
# Create scorer with pre-loaded model
|
||||
scorer = DimensionScorer(
|
||||
model=model,
|
||||
processor=processor,
|
||||
device=device,
|
||||
temperature=settings.softmax_temperature,
|
||||
)
|
||||
lifespan.set_state("scorer", scorer)
|
||||
|
||||
if settings.warmup_on_startup:
|
||||
scorer.warmup()
|
||||
|
||||
logger.info(f"Dimension scorer initialized on {device}")
|
||||
|
||||
|
||||
@lifespan.on_shutdown
|
||||
async def cleanup_service():
|
||||
"""Cancel heartbeat and release GPU lease."""
|
||||
logger.info("Shutting down imajin-classifier service")
|
||||
|
||||
heartbeat_task = lifespan.get_state("heartbeat_task")
|
||||
if heartbeat_task is not None:
|
||||
heartbeat_task.cancel()
|
||||
|
||||
inference_client = lifespan.get_state("inference_client")
|
||||
lease_id = lifespan.get_state("lease_id")
|
||||
if inference_client is not None and lease_id is not None:
|
||||
try:
|
||||
await inference_client.release_lease(lease_id)
|
||||
logger.info("GPU lease released")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Error releasing GPU lease: {exc}")
|
||||
|
||||
if inference_client is not None:
|
||||
try:
|
||||
await inference_client.dispose()
|
||||
except Exception as exc:
|
||||
logger.warning(f"Error disposing InferenceClient: {exc}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Application Setup
|
||||
# =============================================================================
|
||||
|
||||
base_settings = BaseServiceSettings(service_name="imajin-classifier")
|
||||
|
||||
app = FastAPI(
|
||||
title="Imajin Classifier",
|
||||
description="SigLIP2-based multi-dimensional image classifier with arbitrary scoring rubrics",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan.lifespan,
|
||||
)
|
||||
|
||||
apply_cors(app, base_settings.cors_origins)
|
||||
|
||||
app.add_middleware(RequestIDMiddleware)
|
||||
Instrumentator().instrument(app).expose(app, endpoint="/metrics")
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""Catch unhandled exceptions and return structured error response."""
|
||||
request_id = getattr(
|
||||
request.state,
|
||||
"request_id",
|
||||
request.headers.get("X-Request-ID", "unknown"),
|
||||
)
|
||||
logger.error(f"Unhandled exception [request_id={request_id}]: {exc}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "Internal server error", "request_id": request_id},
|
||||
)
|
||||
|
||||
|
||||
# Rate limiting
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_image(image_base64: str) -> Image.Image:
|
||||
"""Decode a base64 image string (raw or data-URL) into a PIL Image."""
|
||||
if image_base64.startswith("data:"):
|
||||
image_base64 = image_base64.split(",", 1)[1]
|
||||
image_data = base64.b64decode(image_base64)
|
||||
return Image.open(io.BytesIO(image_data))
|
||||
|
||||
|
||||
def _resolve_dimensions(body: ClassifyRequest) -> dict[str, dict[str, list[str]]]:
|
||||
"""Resolve scoring dimensions from inline definition or preset.
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If neither dimensions nor preset is provided, or
|
||||
if both are provided, or if the preset is not found.
|
||||
"""
|
||||
has_inline = body.dimensions is not None and len(body.dimensions) > 0
|
||||
has_preset = body.preset is not None
|
||||
|
||||
if has_inline and has_preset:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Provide either 'dimensions' or 'preset', not both.",
|
||||
)
|
||||
if not has_inline and not has_preset:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either 'dimensions' or 'preset' must be provided.",
|
||||
)
|
||||
|
||||
if has_inline:
|
||||
return {
|
||||
name: {"positive": list(dim.positive), "negative": list(dim.negative)}
|
||||
for name, dim in body.dimensions.items() # type: ignore[union-attr]
|
||||
}
|
||||
|
||||
# Preset path
|
||||
preset = preset_registry.get(body.preset) # type: ignore[arg-type]
|
||||
if preset is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Preset '{body.preset}' not found. "
|
||||
f"Available: {preset_registry.names()}",
|
||||
)
|
||||
return preset.resolve(context=body.context)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
scorer = lifespan.get_state("scorer")
|
||||
inference_client = lifespan.get_state("inference_client")
|
||||
lease_id = lifespan.get_state("lease_id")
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "imajin-classifier",
|
||||
"version": "0.1.0",
|
||||
"model_loaded": scorer.is_loaded if scorer else False,
|
||||
"gpu_enabled": scorer.is_gpu_enabled if scorer else False,
|
||||
"device": scorer.device if scorer else None,
|
||||
"gpu_coordination": inference_client is not None and lease_id is not None,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/ready")
|
||||
async def readiness():
|
||||
"""Readiness check — returns 503 until model is loaded and lease is acquired."""
|
||||
scorer = lifespan.get_state("scorer")
|
||||
lease_id = lifespan.get_state("lease_id")
|
||||
if not scorer or not lease_id:
|
||||
raise HTTPException(status_code=503, detail="Scorer not initialized")
|
||||
return {"status": "ready", "service": "imajin-classifier"}
|
||||
|
||||
|
||||
@app.get("/info")
|
||||
async def info():
|
||||
"""Scorer configuration and runtime info."""
|
||||
scorer = lifespan.get_state("scorer")
|
||||
if not scorer:
|
||||
raise HTTPException(status_code=503, detail="Scorer not initialized")
|
||||
return scorer.get_info()
|
||||
|
||||
|
||||
@app.get("/presets", response_model=PresetsResult)
|
||||
async def list_presets():
|
||||
"""List all available scoring presets."""
|
||||
presets = preset_registry.all()
|
||||
return PresetsResult(
|
||||
presets=[
|
||||
PresetInfo(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
dimension_count=len(p.dimension_names),
|
||||
dimensions=p.dimension_names,
|
||||
)
|
||||
for p in presets
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@app.post("/classify", response_model=ClassifyResult)
|
||||
@limiter.limit("30/minute")
|
||||
async def classify(request: Request, body: ClassifyRequest):
|
||||
"""Score an image on multiple dimensions using SigLIP2 zero-shot classification.
|
||||
|
||||
Supply either inline `dimensions` (name → positive/negative prompts) or a
|
||||
`preset` name. When using a preset, pass a `context` dict to resolve
|
||||
context-keyed variant dimensions (e.g. race, gender, combat_type).
|
||||
"""
|
||||
scorer: DimensionScorer = lifespan.get_state("scorer")
|
||||
if not scorer:
|
||||
raise HTTPException(status_code=503, detail="Scorer not initialized")
|
||||
|
||||
try:
|
||||
image = _load_image(body.image_base64)
|
||||
dimensions = _resolve_dimensions(body)
|
||||
|
||||
start = time.perf_counter()
|
||||
scores = scorer.score_dimensions(image, dimensions)
|
||||
processing_time_ms = (time.perf_counter() - start) * 1000
|
||||
|
||||
return ClassifyResult(scores=scores, processing_time_ms=processing_time_ms)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Classification failed: {exc}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@app.post("/calibrate", response_model=CalibrateResult)
|
||||
@limiter.limit("5/minute")
|
||||
async def calibrate(request: Request, body: CalibrateRequest):
|
||||
"""Measure agreement between classifier scores and reference (ground-truth) scores.
|
||||
|
||||
Computes per-dimension Pearson r and mean bias against labeled reference
|
||||
scores. Use this to understand how well the SigLIP2 classifier correlates
|
||||
with human or LLM-generated scores for a given preset.
|
||||
"""
|
||||
scorer: DimensionScorer = lifespan.get_state("scorer")
|
||||
if not scorer:
|
||||
raise HTTPException(status_code=503, detail="Scorer not initialized")
|
||||
|
||||
preset = preset_registry.get(body.preset)
|
||||
if preset is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Preset '{body.preset}' not found. "
|
||||
f"Available: {preset_registry.names()}",
|
||||
)
|
||||
|
||||
try:
|
||||
from scipy.stats import pearsonr # type: ignore[import-untyped]
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Accumulate (classifier_score, reference_score) pairs per dimension
|
||||
per_dim_pairs: dict[str, list[tuple[float, float]]] = {}
|
||||
|
||||
for sample in body.labeled_data:
|
||||
try:
|
||||
image = _load_image(sample.image_base64)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Calibration sample decode failed, skipping: {exc}")
|
||||
continue
|
||||
|
||||
dimensions = preset.resolve(context=sample.context)
|
||||
classifier_scores = scorer.score_dimensions(image, dimensions)
|
||||
|
||||
for dim_name, ref_score in sample.reference_scores.items():
|
||||
if dim_name not in classifier_scores:
|
||||
continue
|
||||
cls_score = classifier_scores[dim_name]
|
||||
per_dim_pairs.setdefault(dim_name, []).append((cls_score, ref_score))
|
||||
|
||||
if not per_dim_pairs:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="No valid dimension pairs could be computed. "
|
||||
"Check that reference_scores use dimension names from the preset "
|
||||
"and that at least some samples decoded successfully.",
|
||||
)
|
||||
|
||||
per_dimension: dict[str, DimensionCalibration] = {}
|
||||
pearson_values: list[float] = []
|
||||
|
||||
for dim_name, pairs in per_dim_pairs.items():
|
||||
cls_vals = [p[0] for p in pairs]
|
||||
ref_vals = [p[1] for p in pairs]
|
||||
n = len(pairs)
|
||||
|
||||
if n < 2:
|
||||
# Pearson r is undefined for fewer than 2 samples
|
||||
per_dimension[dim_name] = DimensionCalibration(
|
||||
pearson_r=float("nan"),
|
||||
bias=float(sum(c - r for c, r in pairs) / n),
|
||||
sample_count=n,
|
||||
)
|
||||
continue
|
||||
|
||||
r_value, _ = pearsonr(cls_vals, ref_vals)
|
||||
bias = sum(c - r for c, r in pairs) / n
|
||||
per_dimension[dim_name] = DimensionCalibration(
|
||||
pearson_r=float(r_value),
|
||||
bias=float(bias),
|
||||
sample_count=n,
|
||||
)
|
||||
pearson_values.append(float(abs(r_value)))
|
||||
|
||||
overall_agreement = sum(pearson_values) / len(pearson_values) if pearson_values else 0.0
|
||||
processing_time_ms = (time.perf_counter() - start) * 1000
|
||||
|
||||
return CalibrateResult(
|
||||
per_dimension=per_dimension,
|
||||
overall_agreement=overall_agreement,
|
||||
processing_time_ms=processing_time_ms,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Calibration failed: {exc}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=settings.service_host,
|
||||
port=settings.service_port,
|
||||
log_level=settings.service_log_level.lower(),
|
||||
)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
"""Configuration for imajin-classifier service."""
|
||||
|
||||
from .settings import settings
|
||||
|
||||
__all__ = ["settings"]
|
||||
38
services/imajin-classifier/service/src/config/settings.py
Normal file
38
services/imajin-classifier/service/src/config/settings.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
"""Configuration settings for the imajin-classifier service."""
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class ClassifierSettings(BaseSettings):
|
||||
"""Multi-dimensional image classifier service configuration.
|
||||
|
||||
GPU coordination is handled by model-boss GPUBoss, which:
|
||||
- Connects to Redis using REDIS_URL env var (default: redis://localhost:6379)
|
||||
- Manages GPU lease acquisition/release automatically
|
||||
|
||||
No GPU configuration needed in this service.
|
||||
"""
|
||||
|
||||
# Service configuration
|
||||
service_host: str = "0.0.0.0"
|
||||
service_port: int = 8007
|
||||
service_log_level: str = "INFO"
|
||||
|
||||
# SigLIP2 model configuration
|
||||
# Options: google/siglip2-so400m-patch14-384, google/siglip2-base-patch16-224
|
||||
model_name: str = "google/siglip2-so400m-patch14-384"
|
||||
|
||||
# Softmax contrastive temperature — lower = more peaked distribution
|
||||
softmax_temperature: float = 0.01
|
||||
|
||||
# Performance configuration
|
||||
warmup_on_startup: bool = True
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="CLASSIFIER_",
|
||||
case_sensitive=False,
|
||||
)
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = ClassifierSettings()
|
||||
25
services/imajin-classifier/service/src/models/__init__.py
Normal file
25
services/imajin-classifier/service/src/models/__init__.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
"""Data models for imajin-classifier service."""
|
||||
|
||||
from .schemas import (
|
||||
CalibrateRequest,
|
||||
CalibrateResult,
|
||||
ClassifyRequest,
|
||||
ClassifyResult,
|
||||
DimensionCalibration,
|
||||
DimensionDef,
|
||||
LabeledSample,
|
||||
PresetInfo,
|
||||
PresetsResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DimensionDef",
|
||||
"ClassifyRequest",
|
||||
"ClassifyResult",
|
||||
"LabeledSample",
|
||||
"CalibrateRequest",
|
||||
"CalibrateResult",
|
||||
"DimensionCalibration",
|
||||
"PresetInfo",
|
||||
"PresetsResult",
|
||||
]
|
||||
228
services/imajin-classifier/service/src/models/schemas.py
Normal file
228
services/imajin-classifier/service/src/models/schemas.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""Data models for the imajin-classifier service."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dimension definition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DimensionDef(BaseModel):
|
||||
"""A single scoring dimension with positive and negative text prompts."""
|
||||
|
||||
positive: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Text prompts that describe the desired quality for this dimension",
|
||||
)
|
||||
negative: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Text prompts that describe the undesired quality for this dimension",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /classify
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ClassifyRequest(BaseModel):
|
||||
"""Request to score an image on one or more dimensions.
|
||||
|
||||
Provide either inline `dimensions` or a `preset` name. When using a preset,
|
||||
`context` is merged with the preset to resolve context-keyed variants
|
||||
(e.g. race, gender, combat_type).
|
||||
"""
|
||||
|
||||
image_base64: str = Field(
|
||||
...,
|
||||
max_length=50_000_000,
|
||||
description="Base64-encoded image data (raw or data-URL format)",
|
||||
)
|
||||
dimensions: Optional[dict[str, DimensionDef]] = Field(
|
||||
None,
|
||||
description="Inline dimension definitions. Mutually exclusive with `preset`.",
|
||||
)
|
||||
preset: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of a preset scoring rubric (e.g. 'sprite_unit_10dim'). "
|
||||
"Mutually exclusive with `dimensions`.",
|
||||
)
|
||||
context: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Key-value context used to resolve context-keyed variants in presets "
|
||||
"(e.g. {'race': 'dwarves', 'gender': 'female', 'combat_type': 'ranged'})",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"summary": "Inline dimensions",
|
||||
"value": {
|
||||
"image_base64": "data:image/png;base64,...",
|
||||
"dimensions": {
|
||||
"camera_angle": {
|
||||
"positive": ["isometric 3/4 elevated view"],
|
||||
"negative": ["front-facing portrait"],
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"summary": "Preset with context",
|
||||
"value": {
|
||||
"image_base64": "data:image/png;base64,...",
|
||||
"preset": "sprite_unit_10dim",
|
||||
"context": {
|
||||
"race": "dwarves",
|
||||
"gender": "female",
|
||||
"combat_type": "ranged",
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ClassifyResult(BaseModel):
|
||||
"""Per-dimension scores for a classified image."""
|
||||
|
||||
scores: dict[str, float] = Field(
|
||||
...,
|
||||
description="Score in [0, 1] for each dimension. "
|
||||
"Higher = image matches the positive prompts more than the negative ones.",
|
||||
)
|
||||
processing_time_ms: float = Field(
|
||||
..., description="Wall-clock processing time in milliseconds"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"scores": {"camera_angle": 0.72, "facing_direction": 0.35},
|
||||
"processing_time_ms": 45.2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /calibrate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class LabeledSample(BaseModel):
|
||||
"""A single labeled image for calibration."""
|
||||
|
||||
image_base64: str = Field(
|
||||
...,
|
||||
max_length=50_000_000,
|
||||
description="Base64-encoded image data",
|
||||
)
|
||||
reference_scores: dict[str, float] = Field(
|
||||
...,
|
||||
description="Ground-truth scores (0-1) per dimension, e.g. from Sonnet scoring",
|
||||
)
|
||||
context: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Context key-values for resolving preset variants (same as /classify)",
|
||||
)
|
||||
|
||||
|
||||
class CalibrateRequest(BaseModel):
|
||||
"""Request to measure agreement between classifier scores and reference scores."""
|
||||
|
||||
labeled_data: list[LabeledSample] = Field(
|
||||
...,
|
||||
min_length=2,
|
||||
description="Labeled images to evaluate",
|
||||
)
|
||||
preset: str = Field(
|
||||
...,
|
||||
description="Name of the preset rubric to evaluate against",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"labeled_data": [
|
||||
{
|
||||
"image_base64": "data:image/png;base64,...",
|
||||
"reference_scores": {"camera_angle": 0.72},
|
||||
"context": {},
|
||||
}
|
||||
],
|
||||
"preset": "sprite_unit_10dim",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class DimensionCalibration(BaseModel):
|
||||
"""Calibration statistics for a single dimension."""
|
||||
|
||||
pearson_r: float = Field(
|
||||
...,
|
||||
description="Pearson correlation coefficient between classifier and reference scores",
|
||||
)
|
||||
bias: float = Field(
|
||||
...,
|
||||
description="Mean signed error (classifier - reference). Negative = classifier undershoots.",
|
||||
)
|
||||
sample_count: int = Field(
|
||||
..., description="Number of samples used for this dimension's calibration"
|
||||
)
|
||||
|
||||
|
||||
class CalibrateResult(BaseModel):
|
||||
"""Calibration results comparing classifier to reference scores."""
|
||||
|
||||
per_dimension: dict[str, DimensionCalibration] = Field(
|
||||
...,
|
||||
description="Per-dimension calibration statistics",
|
||||
)
|
||||
overall_agreement: float = Field(
|
||||
...,
|
||||
ge=0,
|
||||
le=1,
|
||||
description="Mean absolute Pearson r across all dimensions",
|
||||
)
|
||||
processing_time_ms: float = Field(
|
||||
..., description="Total processing time in milliseconds"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"per_dimension": {
|
||||
"camera_angle": {"pearson_r": 0.73, "bias": -0.05, "sample_count": 20}
|
||||
},
|
||||
"overall_agreement": 0.68,
|
||||
"processing_time_ms": 2310.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /presets
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PresetInfo(BaseModel):
|
||||
"""Summary information about a loaded scoring preset."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
dimension_count: int
|
||||
dimensions: list[str]
|
||||
|
||||
|
||||
class PresetsResult(BaseModel):
|
||||
"""List of available scoring presets."""
|
||||
|
||||
presets: list[PresetInfo]
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
"""Scoring rubric presets for imajin-classifier."""
|
||||
|
||||
from .loader import Preset, PresetRegistry, registry
|
||||
|
||||
__all__ = [
|
||||
"Preset",
|
||||
"PresetRegistry",
|
||||
"registry",
|
||||
]
|
||||
165
services/imajin-classifier/service/src/presets/loader.py
Normal file
165
services/imajin-classifier/service/src/presets/loader.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
"""Preset loader for scoring rubric YAML files.
|
||||
|
||||
Presets define a named set of scoring dimensions with positive and negative
|
||||
prompts. Context-keyed variant dimensions (e.g. race_accuracy, gender_accuracy)
|
||||
are resolved at request time by merging a context dict.
|
||||
|
||||
Preset YAML schema
|
||||
------------------
|
||||
A dimension entry is one of:
|
||||
|
||||
Static dimension (prompts known at load time)::
|
||||
|
||||
camera_angle:
|
||||
positive: ["..."]
|
||||
negative: ["..."]
|
||||
|
||||
Context-keyed dimension (prompts depend on a runtime context value)::
|
||||
|
||||
race_accuracy:
|
||||
context_key: race # key to look up in the request context dict
|
||||
variants:
|
||||
dwarves:
|
||||
positive: ["..."]
|
||||
negative: ["..."]
|
||||
humans:
|
||||
positive: ["..."]
|
||||
negative: ["..."]
|
||||
|
||||
When resolving context-keyed dimensions, if the context key is missing or
|
||||
its value has no matching variant, the dimension is skipped.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PRESETS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
class Preset:
|
||||
"""A loaded scoring rubric preset.
|
||||
|
||||
Attributes:
|
||||
name: Unique preset identifier (matches filename stem).
|
||||
description: Human-readable description.
|
||||
raw_dimensions: Raw dimension definitions from the YAML file.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, description: str, raw_dimensions: dict) -> None:
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.raw_dimensions = raw_dimensions
|
||||
|
||||
@property
|
||||
def dimension_names(self) -> list[str]:
|
||||
"""All dimension names defined in this preset."""
|
||||
return list(self.raw_dimensions.keys())
|
||||
|
||||
def resolve(self, context: Optional[dict[str, str]] = None) -> dict[str, dict[str, list[str]]]:
|
||||
"""Resolve the preset into a flat dimensions map for the scorer.
|
||||
|
||||
Context-keyed dimensions are resolved using `context`. Dimensions
|
||||
whose context key is absent or whose variant is not found are skipped.
|
||||
|
||||
Args:
|
||||
context: Key-value pairs used to resolve variant dimensions.
|
||||
|
||||
Returns:
|
||||
Mapping of dimension name to {"positive": [...], "negative": [...]}.
|
||||
"""
|
||||
context = context or {}
|
||||
resolved: dict[str, dict[str, list[str]]] = {}
|
||||
|
||||
for dim_name, dim_def in self.raw_dimensions.items():
|
||||
if "context_key" in dim_def:
|
||||
# Context-keyed variant dimension
|
||||
ctx_key = dim_def["context_key"]
|
||||
ctx_value = context.get(ctx_key)
|
||||
if ctx_value is None:
|
||||
logger.debug(
|
||||
f"Preset dimension '{dim_name}' skipped: "
|
||||
f"context key '{ctx_key}' not provided"
|
||||
)
|
||||
continue
|
||||
variants = dim_def.get("variants", {})
|
||||
variant = variants.get(ctx_value)
|
||||
if variant is None:
|
||||
logger.debug(
|
||||
f"Preset dimension '{dim_name}' skipped: "
|
||||
f"no variant '{ctx_value}' for context key '{ctx_key}'"
|
||||
)
|
||||
continue
|
||||
resolved[dim_name] = {
|
||||
"positive": list(variant.get("positive") or []),
|
||||
"negative": list(variant.get("negative") or []),
|
||||
}
|
||||
else:
|
||||
# Static dimension
|
||||
resolved[dim_name] = {
|
||||
"positive": list(dim_def.get("positive") or []),
|
||||
"negative": list(dim_def.get("negative") or []),
|
||||
}
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
def _load_preset_file(path: Path) -> Preset:
|
||||
"""Load and parse a single YAML preset file."""
|
||||
with path.open("r", encoding="utf-8") as fh:
|
||||
data = yaml.safe_load(fh)
|
||||
|
||||
name = data.get("name") or path.stem
|
||||
description = data.get("description", "")
|
||||
raw_dimensions = data.get("dimensions", {})
|
||||
return Preset(name=name, description=description, raw_dimensions=raw_dimensions)
|
||||
|
||||
|
||||
class PresetRegistry:
|
||||
"""Registry of all available scoring presets.
|
||||
|
||||
Loads all *.yaml files from the presets directory on first access.
|
||||
"""
|
||||
|
||||
def __init__(self, presets_dir: Path = _PRESETS_DIR) -> None:
|
||||
self._presets_dir = presets_dir
|
||||
self._cache: Optional[dict[str, Preset]] = None
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
if self._cache is not None:
|
||||
return
|
||||
|
||||
self._cache = {}
|
||||
for yaml_path in sorted(self._presets_dir.glob("*.yaml")):
|
||||
try:
|
||||
preset = _load_preset_file(yaml_path)
|
||||
self._cache[preset.name] = preset
|
||||
logger.info(
|
||||
f"Loaded preset '{preset.name}' "
|
||||
f"({len(preset.dimension_names)} dimensions)"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to load preset from {yaml_path}: {exc}")
|
||||
|
||||
def get(self, name: str) -> Optional[Preset]:
|
||||
"""Return a preset by name, or None if not found."""
|
||||
self._ensure_loaded()
|
||||
return self._cache.get(name) # type: ignore[union-attr]
|
||||
|
||||
def all(self) -> list[Preset]:
|
||||
"""Return all loaded presets."""
|
||||
self._ensure_loaded()
|
||||
return list(self._cache.values()) # type: ignore[union-attr]
|
||||
|
||||
def names(self) -> list[str]:
|
||||
"""Return all preset names."""
|
||||
self._ensure_loaded()
|
||||
return list(self._cache.keys()) # type: ignore[union-attr]
|
||||
|
||||
|
||||
# Module-level singleton registry
|
||||
registry = PresetRegistry()
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
name: sprite_unit_10dim
|
||||
description: 10-dimension scoring rubric for fantasy game unit sprites
|
||||
model: google/siglip2-so400m-patch14-384
|
||||
|
||||
dimensions:
|
||||
camera_angle:
|
||||
positive:
|
||||
- "isometric 3/4 elevated camera angle looking down at character"
|
||||
- "game sprite seen from above at 45 degree angle"
|
||||
negative:
|
||||
- "front-facing portrait eye-level view"
|
||||
- "top-down bird's eye view looking straight down"
|
||||
|
||||
facing_direction:
|
||||
positive:
|
||||
- "character facing bottom-left walking southwest"
|
||||
- "rear three-quarter view character walking away to the left"
|
||||
negative:
|
||||
- "character facing forward toward the camera"
|
||||
- "character facing right walking to the right"
|
||||
|
||||
composition:
|
||||
positive:
|
||||
- "single character game sprite clean silhouette isolated"
|
||||
negative:
|
||||
- "multiple characters crowd group scene"
|
||||
- "character sheet turnaround reference page"
|
||||
|
||||
subject_type:
|
||||
# positive is empty by default; fill from context.entity_description if provided
|
||||
positive: []
|
||||
negative:
|
||||
- "abstract shape generic blob formless"
|
||||
|
||||
race_accuracy:
|
||||
context_key: race
|
||||
variants:
|
||||
dwarves:
|
||||
positive:
|
||||
- "short stocky dwarf character wide body"
|
||||
- "fantasy dwarf short and broad muscular"
|
||||
negative:
|
||||
- "tall slender human normal proportions"
|
||||
- "large orc green skin"
|
||||
humans:
|
||||
positive:
|
||||
- "average height human character normal proportions"
|
||||
- "human warrior regular build"
|
||||
negative:
|
||||
- "short stocky dwarf"
|
||||
- "large orc green skin"
|
||||
- "tall slender pointed ears elf"
|
||||
elves:
|
||||
positive:
|
||||
- "tall slender elf pointed ears graceful"
|
||||
- "lithe elegant high elf long pointed ears"
|
||||
negative:
|
||||
- "short stocky dwarf"
|
||||
- "large green orc"
|
||||
orcs:
|
||||
positive:
|
||||
- "large muscular orc green grey skin tusks"
|
||||
- "imposing orc warrior broad powerful"
|
||||
negative:
|
||||
- "short stocky dwarf"
|
||||
- "slender elf"
|
||||
|
||||
gender_accuracy:
|
||||
context_key: gender
|
||||
variants:
|
||||
male:
|
||||
positive:
|
||||
- "male character masculine thick beard"
|
||||
negative:
|
||||
- "female character feminine no beard"
|
||||
female:
|
||||
positive:
|
||||
- "female character feminine braided hair no beard"
|
||||
negative:
|
||||
- "male character thick beard masculine"
|
||||
|
||||
equipment_accuracy:
|
||||
context_key: combat_type
|
||||
variants:
|
||||
melee:
|
||||
positive:
|
||||
- "warrior holding sword or axe heavy armor"
|
||||
negative:
|
||||
- "wrong weapon unarmed"
|
||||
ranged:
|
||||
positive:
|
||||
- "archer holding bow arrows quiver crossbow"
|
||||
negative:
|
||||
- "wrong weapon unarmed"
|
||||
cavalry:
|
||||
positive:
|
||||
- "rider mounted on horse warhorse"
|
||||
negative:
|
||||
- "on foot walking no mount"
|
||||
civilian:
|
||||
positive:
|
||||
- "civilian carrying tools supplies no armor"
|
||||
negative:
|
||||
- "armored warrior with weapons"
|
||||
|
||||
pose_quality:
|
||||
positive:
|
||||
- "full body character visible from head to feet dynamic pose"
|
||||
negative:
|
||||
- "character cropped at knees"
|
||||
- "stiff T-pose mannequin"
|
||||
|
||||
background_compliance:
|
||||
positive:
|
||||
- "solid bright green background chroma key"
|
||||
- "green screen studio background solid color"
|
||||
negative:
|
||||
- "brown terrain ground landscape background"
|
||||
- "scenery forest mountains sky"
|
||||
|
||||
art_style:
|
||||
positive:
|
||||
- "hand-painted digital fantasy game art bold colors"
|
||||
- "stylized fantasy RPG character art bold shapes"
|
||||
negative:
|
||||
- "photorealistic photograph realistic portrait"
|
||||
- "anime manga cartoon"
|
||||
- "pixel art retro 8-bit"
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
"""Scoring components for imajin-classifier service."""
|
||||
|
||||
from .dimension_scorer import DEFAULT_TEMPERATURE, DimensionScorer
|
||||
|
||||
__all__ = [
|
||||
"DimensionScorer",
|
||||
"DEFAULT_TEMPERATURE",
|
||||
]
|
||||
|
|
@ -0,0 +1,217 @@
|
|||
"""Multi-dimensional contrastive image scorer using SigLIP2.
|
||||
|
||||
Scoring approach: for each dimension, the positive and negative prompts are
|
||||
gathered into a single pool and scored in one forward pass through SigLIP2.
|
||||
A softmax with low temperature sharpens the distribution, and the sum of
|
||||
probability mass on positive prompts becomes the dimension score.
|
||||
|
||||
This is identical to the contrastive approach used by imajin-semantic's
|
||||
_compute_similarity — applied per-dimension with explicit pos/neg framing.
|
||||
|
||||
Model loading and GPU coordination are handled externally by model-boss.
|
||||
This scorer accepts pre-loaded model/processor instances.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default softmax temperature — lower = more peaked distribution
|
||||
DEFAULT_TEMPERATURE = 0.01
|
||||
|
||||
|
||||
class DimensionScorer:
|
||||
"""SigLIP2-based scorer for arbitrary multi-dimensional rubrics.
|
||||
|
||||
For each scoring dimension, the caller supplies positive and negative
|
||||
text prompts. The scorer runs a single forward pass, applies softmax
|
||||
over the combined prompt pool, and sums the probability mass on
|
||||
positive prompts to produce a score in [0, 1].
|
||||
|
||||
When a dimension has no positive prompts (only negatives), the score
|
||||
is derived as 1 - (sum of negative probability mass), i.e. the model
|
||||
checks how strongly the image does NOT match the negatives.
|
||||
When both lists are empty, the dimension scores 0.5 (neutral / unknown).
|
||||
|
||||
Example::
|
||||
|
||||
scorer = DimensionScorer(model=model, processor=processor, device="cuda:0")
|
||||
scores = scorer.score_dimensions(
|
||||
image=pil_image,
|
||||
dimensions={
|
||||
"camera_angle": {
|
||||
"positive": ["isometric 3/4 elevated view"],
|
||||
"negative": ["front-facing portrait"],
|
||||
}
|
||||
},
|
||||
)
|
||||
# {"camera_angle": 0.72}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Any,
|
||||
processor: Any,
|
||||
device: str,
|
||||
temperature: float = DEFAULT_TEMPERATURE,
|
||||
) -> None:
|
||||
"""Initialize the dimension scorer.
|
||||
|
||||
Args:
|
||||
model: Pre-loaded SigLIP2 model instance.
|
||||
processor: Pre-loaded SigLIP2 processor instance.
|
||||
device: Device string (e.g. "cuda:0", "cpu").
|
||||
temperature: Softmax temperature. Lower values make the distribution
|
||||
more peaked, increasing discrimination. Default 0.01.
|
||||
"""
|
||||
self._model = model
|
||||
self._processor = processor
|
||||
self._device = device
|
||||
self._temperature = temperature
|
||||
|
||||
self._model.to(self._device)
|
||||
self._model.eval()
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
return self._device
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
return self._model is not None
|
||||
|
||||
@property
|
||||
def is_gpu_enabled(self) -> bool:
|
||||
return "cuda" in self._device and torch.cuda.is_available()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core inference
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _embed(
|
||||
self,
|
||||
image: Image.Image,
|
||||
text_prompts: list[str],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Run one SigLIP2 forward pass and return normalized embeddings.
|
||||
|
||||
Returns:
|
||||
(image_embeds, text_embeds) — both L2-normalized, on CPU.
|
||||
"""
|
||||
inputs = self._processor(
|
||||
text=text_prompts,
|
||||
images=image,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
img = outputs.image_embeds
|
||||
txt = outputs.text_embeds
|
||||
|
||||
img = img / img.norm(dim=-1, keepdim=True)
|
||||
txt = txt / txt.norm(dim=-1, keepdim=True)
|
||||
|
||||
return img.cpu(), txt.cpu()
|
||||
|
||||
def _contrastive_score(
|
||||
self,
|
||||
image: Image.Image,
|
||||
positive: list[str],
|
||||
negative: list[str],
|
||||
) -> float:
|
||||
"""Score a single dimension via softmax contrastive classification.
|
||||
|
||||
Runs one forward pass with all prompts combined, then sums the
|
||||
softmax probability mass assigned to positive prompts.
|
||||
|
||||
Edge cases:
|
||||
- No prompts at all: returns 0.5 (neutral / unknown).
|
||||
- Only positives: sum of softmax mass over all positive prompts
|
||||
(approaches 1.0, meaningful only as an absolute quality signal —
|
||||
negatives should be supplied for proper discrimination).
|
||||
- Only negatives: 1 - (negative probability mass).
|
||||
|
||||
Args:
|
||||
image: PIL Image to score.
|
||||
positive: Positive-class text prompts.
|
||||
negative: Negative-class text prompts.
|
||||
|
||||
Returns:
|
||||
Score in [0, 1].
|
||||
"""
|
||||
all_prompts = positive + negative
|
||||
if not all_prompts:
|
||||
return 0.5
|
||||
|
||||
img_embeds, txt_embeds = self._embed(image, all_prompts)
|
||||
|
||||
# Cosine similarities — shape: (num_prompts,)
|
||||
sims = (img_embeds @ txt_embeds.T).squeeze(0)
|
||||
|
||||
# Softmax with temperature sharpening
|
||||
probs = torch.softmax(sims / self._temperature, dim=-1)
|
||||
|
||||
n_pos = len(positive)
|
||||
if n_pos == 0:
|
||||
# Only negatives supplied — score is inverse of negative mass
|
||||
neg_mass = probs.sum().item()
|
||||
return float(1.0 - neg_mass)
|
||||
|
||||
pos_mass = probs[:n_pos].sum().item()
|
||||
return float(pos_mass)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def score_dimensions(
|
||||
self,
|
||||
image: Image.Image,
|
||||
dimensions: dict[str, dict[str, list[str]]],
|
||||
) -> dict[str, float]:
|
||||
"""Score an image on multiple dimensions.
|
||||
|
||||
Args:
|
||||
image: PIL Image to score.
|
||||
dimensions: Mapping of dimension name to {"positive": [...], "negative": [...]}.
|
||||
|
||||
Returns:
|
||||
Mapping of dimension name to score in [0, 1].
|
||||
"""
|
||||
scores: dict[str, float] = {}
|
||||
for dim_name, prompts in dimensions.items():
|
||||
positive = prompts.get("positive") or []
|
||||
negative = prompts.get("negative") or []
|
||||
scores[dim_name] = self._contrastive_score(image, positive, negative)
|
||||
return scores
|
||||
|
||||
def warmup(self) -> None:
|
||||
"""Run a dummy inference to ensure the first real request is fast."""
|
||||
logger.info("Warming up SigLIP2 model for classifier...")
|
||||
start = time.perf_counter()
|
||||
dummy_image = Image.new("RGB", (384, 384), color=(128, 128, 128))
|
||||
_ = self._contrastive_score(
|
||||
dummy_image,
|
||||
positive=["a test image"],
|
||||
negative=["not a test image"],
|
||||
)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
logger.info(f"SigLIP2 warmup completed in {elapsed_ms:.0f}ms")
|
||||
|
||||
def get_info(self) -> dict:
|
||||
"""Return runtime info for health and info endpoints."""
|
||||
return {
|
||||
"device": self._device,
|
||||
"gpu_available": torch.cuda.is_available(),
|
||||
"gpu_enabled": self.is_gpu_enabled,
|
||||
"initialized": self.is_loaded,
|
||||
"temperature": self._temperature,
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue