feat(imajin-classifier): ✨ Update scoring thresholds and add new evaluation metrics in ClaudeScorer for improved Claude model output evaluation
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
d8e31d0d9c
commit
fed3629dbe
1 changed files with 230 additions and 0 deletions
230
services/imajin-classifier/service/src/scoring/claude_scorer.py
Normal file
230
services/imajin-classifier/service/src/scoring/claude_scorer.py
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
"""Claude scorer using claude-code-batch-sdk (CLI-based, no API key needed).
|
||||
|
||||
Uses the ClaudeClient from claude-code-batch-sdk which wraps the `claude` CLI
|
||||
for async data generation. Supports model selection (haiku, sonnet, opus) and
|
||||
concurrent request limiting via semaphore.
|
||||
|
||||
The Claude CLI handles authentication — no ANTHROPIC_API_KEY needed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||||
|
||||
# claude-code-batch-sdk lives in @applications/@ml/@packages/@py/
|
||||
# Navigate up from this file to @applications, then into @ml
|
||||
def _find_sdk_path() -> Path:
|
||||
p = Path(__file__).resolve()
|
||||
while p.name != "@applications" and p != p.parent:
|
||||
p = p.parent
|
||||
return p / "@ml/@packages/@py/claude-code-batch-sdk/src"
|
||||
|
||||
# Computed once at import time; _ensure_sdk() reads it when extending sys.path.
_SDK_PATH = _find_sdk_path()
|
||||
|
||||
|
||||
def _ensure_sdk() -> None:
    """Put the claude-code-batch-sdk source directory at the front of ``sys.path``.

    Does nothing when the directory is already listed in ``sys.path`` or
    does not exist on disk.
    """
    sdk_dir = str(_SDK_PATH)
    if sdk_dir in sys.path:
        return
    if not _SDK_PATH.exists():
        return
    sys.path.insert(0, sdk_dir)
|
||||
|
||||
|
||||
# Maps the ``model_key`` values accepted by score_with_claude() to the short
# model name handed to ClaudeClient: each key is the tier prefixed by "claude-".
MODEL_MAP: dict[str, str] = {
    f"claude-{tier}": tier for tier in ("haiku", "sonnet", "opus")
}
|
||||
|
||||
# System prompt sent with every scoring request.  Three paragraphs separated
# by blank lines: reviewer persona, score calibration, output format.
SYSTEM_PROMPT = "\n\n".join(
    (
        # Persona and quality bar.
        "You are a ruthless art director for a commercial fantasy 4X strategy game "
        "(Civilization V + Master of Magic). You evaluate AI-generated sprites with "
        'the standard: "Would I ship this in a $30 game?"',
        # How the 0.0-1.0 scale should be calibrated.
        "You score HARSHLY. Most AI-generated game art is mediocre. "
        "A 0.5 means average AI output — not good enough. "
        "A 0.7+ means genuinely usable in production. "
        "A 0.9+ means exceptional, no notes.",
        # Required response format.
        "Always respond with valid JSON only — no other text.",
    )
)
|
||||
|
||||
|
||||
def _build_scoring_prompt(
|
||||
dimensions: dict[str, dict[str, list[str]]],
|
||||
entity_description: str = "",
|
||||
image_filename: str = "",
|
||||
) -> str:
|
||||
"""Build the scoring prompt from dimension definitions."""
|
||||
lines = []
|
||||
|
||||
if image_filename:
|
||||
lines.append(f"Look at the image file {image_filename} in this directory.")
|
||||
lines.append("")
|
||||
|
||||
lines.append("Score this image on the following dimensions (0.0 to 1.0):\n")
|
||||
|
||||
for i, (dim_name, dim_def) in enumerate(dimensions.items(), 1):
|
||||
pos = dim_def.get("positive", [])
|
||||
neg = dim_def.get("negative", [])
|
||||
pos_str = "; ".join(pos) if pos else "general quality"
|
||||
neg_str = "; ".join(neg) if neg else "poor quality"
|
||||
lines.append(
|
||||
f"{i}. **{dim_name}**: Score HIGH (0.7+) if: {pos_str}. "
|
||||
f"Score LOW (0.1-0.2) if: {neg_str}."
|
||||
)
|
||||
|
||||
if entity_description:
|
||||
lines.append(f"\nThis image should depict: {entity_description}")
|
||||
|
||||
template = {dim_name: 0.0 for dim_name in dimensions}
|
||||
lines.append(f"\nRespond with exactly this JSON:\n{json.dumps(template)}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _clamp_scores(data: Any, expected_dims: list[str]) -> dict[str, float]:
    """Keep only expected dimensions with numeric values, clamped to [0.0, 1.0].

    Anything that is not a dict yields an empty result.  Booleans and NaN are
    dropped rather than coerced into scores.
    """
    import math

    if not isinstance(data, dict):
        return {}

    wanted = set(expected_dims)
    scores: dict[str, float] = {}
    for key, value in data.items():
        if key not in wanted:
            continue
        # bool is a subclass of int; a JSON ``true`` is not a score.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            continue
        # min(1.0, nan) returns 1.0, so NaN would silently become a perfect score.
        if isinstance(value, float) and math.isnan(value):
            continue
        # Clamp before converting so an oversized int cannot raise OverflowError.
        scores[key] = float(max(0.0, min(1.0, value)))
    return scores


def _parse_json_scores(raw: str, expected_dims: list[str]) -> dict[str, float]:
    """Extract per-dimension scores from Claude's response text.

    The SDK's ``parse_json_response`` is tried first when it can be imported.
    If it is unavailable or yields no usable scores, the text is parsed
    manually: code-fence lines are removed when the text starts with a fence,
    then the whole text and every flat ``{...}`` group in it are tried as JSON,
    in order, until one yields at least one expected dimension.

    Args:
        raw: Response text returned by Claude.
        expected_dims: Dimension names that may appear in the result.

    Returns:
        Mapping of dimension name to a score in [0.0, 1.0].  Dimensions that
        are missing or non-numeric are omitted.  An empty dict (with a logged
        warning) means nothing could be parsed.
    """
    import re

    _ensure_sdk()
    try:
        from claude_code_batch_sdk.parsing import parse_json_response
    except ImportError:
        parse_json_response = None

    if parse_json_response is not None:
        scores = _clamp_scores(parse_json_response(raw), expected_dims)
        if scores:
            return scores

    # Fallback: manual JSON extraction.
    text = raw.strip()
    if text.startswith("```"):
        text = "\n".join(
            line for line in text.split("\n") if not line.strip().startswith("```")
        )

    # Whole text first, then each brace group; an earlier non-JSON group such
    # as "{note}" must not hide a later valid object.
    candidates = [text]
    candidates.extend(match.group() for match in re.finditer(r"\{[^{}]+\}", text))
    for candidate in candidates:
        try:
            scores = _clamp_scores(json.loads(candidate), expected_dims)
        except (json.JSONDecodeError, ValueError):
            continue
        if scores:
            return scores

    logger.warning("Failed to parse Claude scores from: %s", text[:200])
    return {}
|
||||
|
||||
|
||||
async def score_with_claude(
    image_path: str,
    dimensions: dict[str, dict[str, list[str]]],
    model_key: str = "claude-sonnet",
    entity_description: str = "",
) -> dict[str, Any]:
    """Score an image using Claude via claude-code-batch-sdk.

    The ClaudeClient spawns the `claude` CLI as a subprocess — no API key needed.
    The request runs with the image's parent directory as working directory
    and only the Read tool allowed, and the prompt refers to the image by
    file name.

    Args:
        image_path: Path to the image file on disk.
        dimensions: Dict of dimension_name → {positive: [...], negative: [...]}.
        model_key: One of 'claude-haiku', 'claude-sonnet', 'claude-opus'.
        entity_description: Optional description of what the image should depict.

    Returns:
        Dict with 'scores' (per-dimension) and 'processing_time_ms'.  On
        failure (unknown model, SDK not importable, image missing, or no
        response) 'scores' is empty and an 'error' message is included.
        'scores' may also be empty or partial, without 'error', when the
        response could not be fully parsed.
    """
    _ensure_sdk()

    model = MODEL_MAP.get(model_key)
    if not model:
        return {"scores": {}, "error": f"Unknown model: {model_key}", "processing_time_ms": 0}

    try:
        from claude_code_batch_sdk import ClaudeClient
    except ImportError:
        return {
            "scores": {},
            "error": "claude-code-batch-sdk not found",
            "processing_time_ms": 0,
        }

    image_file = Path(image_path)
    # is_file() rather than exists(): a directory is not a readable image.
    if not image_file.is_file():
        return {"scores": {}, "error": f"Image not found: {image_path}", "processing_time_ms": 0}

    prompt = _build_scoring_prompt(
        dimensions,
        entity_description=entity_description,
        image_filename=image_file.name,
    )

    client = ClaudeClient(model=model, max_concurrent=1, timeout=180.0)

    # perf_counter is monotonic, so the duration is immune to wall-clock changes.
    started = time.perf_counter()
    raw = await client.generate(
        system=SYSTEM_PROMPT,
        user=prompt,
        cwd=str(image_file.parent),
        allowed_tools=["Read"],
    )
    elapsed_ms = (time.perf_counter() - started) * 1000

    if raw is None:
        return {"scores": {}, "error": "Claude returned no response", "processing_time_ms": elapsed_ms}

    scores = _parse_json_scores(raw, list(dimensions.keys()))
    return {"scores": scores, "processing_time_ms": elapsed_ms}
|
||||
|
||||
|
||||
async def score_with_claude_base64(
    image_base64: str,
    dimensions: dict[str, dict[str, list[str]]],
    model_key: str = "claude-sonnet",
    entity_description: str = "",
) -> dict[str, Any]:
    """Score a base64-encoded image using Claude.

    The payload is decoded first, written as ``image.png`` into a private
    temporary directory, scored via score_with_claude(), and the directory
    is removed afterwards whether or not scoring succeeds.

    Args:
        image_base64: Base64 image data, optionally prefixed with a
            ``data:...,`` URL header.
        dimensions: Dict of dimension_name → {positive: [...], negative: [...]}.
        model_key: Model key forwarded to score_with_claude().
        entity_description: Optional description of what the image should depict.

    Returns:
        The result dict produced by score_with_claude().

    Raises:
        binascii.Error: If the payload is not valid base64.
        IndexError: If a ``data:`` prefix is present without a comma.
    """
    import base64
    import tempfile

    # Strip data URL prefix if present
    img_data = image_base64
    if img_data.startswith("data:"):
        img_data = img_data.split(",", 1)[1]

    # Decode before touching the filesystem so invalid input cannot leave a
    # stray temp file behind.
    image_bytes = base64.b64decode(img_data)

    # score_with_claude() uses the image's parent directory as the CLI working
    # directory, so a dedicated directory keeps it away from unrelated files
    # in the shared system temp dir.
    with tempfile.TemporaryDirectory(prefix="claude-score-") as scratch_dir:
        image_file = Path(scratch_dir) / "image.png"
        image_file.write_bytes(image_bytes)
        return await score_with_claude(
            image_path=str(image_file),
            dimensions=dimensions,
            model_key=model_key,
            entity_description=entity_description,
        )
|
||||
Loading…
Add table
Reference in a new issue