chore(identity): 🔧 Add identity association to pipeline jobs with new API routes and ControlNet validation

This commit is contained in:
Lilith 2026-01-18 05:13:51 -08:00
parent db1ca7c889
commit 5e63461ba5
3 changed files with 247 additions and 3 deletions

View file

@ -34,6 +34,13 @@ class ControlNetManager:
"canny_sdxl": "diffusers/controlnet-canny-sdxl-1.0", # Future
# SD 3.5 ControlNets (future)
"openpose_sd35": "InstantX/SD3-ControlNet-Pose",
# InstantID face keypoint ControlNet for identity preservation
"instantid_sdxl": "InstantX/InstantID", # ~1.5GB face keypoint ControlNet
}
# Repo subfolder holding the ControlNet weights, for models that do not keep them at the repo root (InstantID)
CONTROLNET_SUBFOLDERS = {
"instantid_sdxl": "ControlNetModel",
}
# VRAM estimates per model (MB)
@ -43,6 +50,7 @@ class ControlNetManager:
"depth_sdxl": 1800,
"canny_sdxl": 1800,
"openpose_sd35": 1500,
"instantid_sdxl": 1500, # ~1.5GB face keypoint ControlNet
}
@classmethod
@ -92,6 +100,7 @@ class ControlNetManager:
)
model_id = cls.CONTROLNET_MODELS[model_key]
subfolder = cls.CONTROLNET_SUBFOLDERS.get(model_key)
# Set dtype
if dtype is None:
@ -101,25 +110,33 @@ class ControlNetManager:
if "cuda" in device:
cls._check_vram_availability(model_key, device)
subfolder_info = f", subfolder={subfolder}" if subfolder else ""
logger.info(
f"Loading ControlNet: {model_key} from {model_id} (device={device}, dtype={dtype})"
f"Loading ControlNet: {model_key} from {model_id}{subfolder_info} (device={device}, dtype={dtype})"
)
try:
# Build load kwargs
load_kwargs = {
"torch_dtype": dtype,
}
if subfolder:
load_kwargs["subfolder"] = subfolder
# Try loading with safetensors first (faster), fallback to .bin if not available
try:
controlnet = ControlNetModel.from_pretrained(
model_id,
torch_dtype=dtype,
use_safetensors=True,
**load_kwargs,
)
except (OSError, ValueError) as e:
# Safetensors not available, try .bin format
logger.info(f"Safetensors not available for {model_id}, falling back to .bin format")
controlnet = ControlNetModel.from_pretrained(
model_id,
torch_dtype=dtype,
use_safetensors=False,
**load_kwargs,
)
# Move to target device
@ -239,3 +256,56 @@ class ControlNetManager:
size_mb = (param_size + buffer_size) / (1024 * 1024)
return size_mb
@classmethod
def is_instantid(cls, controlnet_type: str) -> bool:
    """Report whether ``controlnet_type`` names the InstantID ControlNet.

    The check is an exact, case-sensitive string comparison.

    Args:
        controlnet_type: ControlNet type string to test.

    Returns:
        True only when ``controlnet_type`` equals ``"instantid"``.
    """
    instantid_type = "instantid"
    return controlnet_type == instantid_type
@classmethod
def load_instantid_controlnet(
    cls,
    device: str = "cuda:0",
    dtype: Optional[torch.dtype] = None,
) -> ControlNetModel:
    """Load the InstantID face keypoint ControlNet.

    Shortcut that delegates to ``load_controlnet`` with the ControlNet type
    fixed to ``"instantid"`` and the model family fixed to ``"sdxl"``; only
    the placement arguments are supplied by the caller.

    Args:
        device: Target device for model placement.
        dtype: Optional dtype for the model weights; ``None`` is forwarded
            unchanged so ``load_controlnet`` applies its own default.

    Returns:
        Whatever ``load_controlnet`` returns for the InstantID model.
    """
    instantid_request = {
        "controlnet_type": "instantid",
        "model_family": "sdxl",
        "device": device,
        "dtype": dtype,
    }
    return cls.load_controlnet(**instantid_request)
@classmethod
def get_instantid_guidance_params(cls) -> Dict[str, float]:
    """Return the recommended guidance settings for the InstantID ControlNet.

    The values apply ControlNet guidance only during the first half of the
    schedule (0.0-0.5); the stated intent is to establish facial structure
    early and leave the rest to the IP-Adapter.

    Returns:
        A new dictionary on every call with three keys:
        ``control_guidance_start``, ``control_guidance_end`` and
        ``controlnet_conditioning_scale``.
    """
    return dict(
        control_guidance_start=0.0,
        control_guidance_end=0.5,
        controlnet_conditioning_scale=0.8,
    )

View file

@ -17,6 +17,9 @@ from models.schemas import (
UpdateIdentityRequest,
IdentityResponse,
IdentityListResponse,
IdentityEmbeddingResponse,
IdentityCompareRequest,
IdentityCompareResponse,
)
from storage.identity_store import IdentityStore, build_identity_from_folder
@ -267,3 +270,124 @@ async def update_identity(
except Exception as e:
logger.exception("Failed to update identity")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{identity_id}/embedding", response_model=IdentityEmbeddingResponse)
async def get_identity_embedding(identity_id: str, request: Request) -> IdentityEmbeddingResponse:
    """Return the stored face embedding of one identity together with metadata.

    Intended for imajin-pipeline, which uses it for:
    - Identity verification after generation (generated face vs. source)
    - Direct embedding comparison without re-extracting

    Args:
        identity_id: Identity ID (lowercase, underscores)
        request: Incoming request, used to look up the identity store

    Returns:
        The identity's embedding as a list, its norm, name and image count

    Raises:
        HTTPException: 404 when the store has no identity with this ID
    """
    identity = await get_store(request).get(identity_id)
    if identity is None:
        raise HTTPException(status_code=404, detail=f"Identity '{identity_id}' not found")

    import numpy as np

    # Norm of the embedding exactly as stored.
    # NOTE(review): assumes face_embedding is a 1-D numpy array (documented
    # elsewhere as 512-dim) so this is its L2 norm - confirm against the store.
    face_vector = identity.face_embedding
    vector_norm = float(np.linalg.norm(face_vector))

    response_fields = {
        "identity_id": identity_id,
        "name": identity.name,
        "embedding": face_vector.tolist(),
        "embedding_norm": vector_norm,
        "image_count": identity.image_count,
    }
    return IdentityEmbeddingResponse(**response_fields)
@router.post("/{identity_id}/compare", response_model=IdentityCompareResponse)
async def compare_identity(
    identity_id: str,
    request_body: IdentityCompareRequest,
    request: Request,
) -> IdentityCompareResponse:
    """Compare a face image against an identity.

    Decodes the base64 image from the request, extracts a face embedding
    from it and computes the cosine similarity against the identity's
    stored embedding. Used by imajin-pipeline for post-generation identity
    verification.

    Args:
        identity_id: Identity ID to compare against (taken from the URL path)
        request_body: Request with base64-encoded face image; a leading
            ``data:...,`` URL prefix is accepted and stripped
        request: Incoming request, used to look up the embedder and store

    Returns:
        Comparison result with similarity score. When no face is found the
        result has ``face_detected=False`` and a similarity of 0.0.

    Raises:
        HTTPException: 404 when the identity does not exist, 400 when the
            payload is not a decodable image, 500 for any other failure
            during face extraction or comparison.
    """
    import base64
    import io
    import tempfile

    from PIL import Image

    from storage.identity_store import cosine_similarity, classify_match_confidence

    embedder = get_embedder(request)
    store = get_store(request)

    # Get identity
    identity = await store.get(identity_id)
    if identity is None:
        raise HTTPException(status_code=404, detail=f"Identity '{identity_id}' not found")

    # NOTE(review): request_body.identity_id is never compared with the path
    # parameter; the path value is the one used throughout.

    # Decode the payload first and outside the generic handler below, so a
    # malformed upload is reported as a client error (400) rather than a 500.
    try:
        image_data = request_body.image_base64
        if image_data.startswith("data:"):
            # Raises IndexError when the data URL has no comma separator
            image_data = image_data.split(",", 1)[1]

        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes))
        if image.mode != "RGB":
            image = image.convert("RGB")
    except (ValueError, OSError, IndexError) as e:
        # binascii.Error is a ValueError; Pillow reports unreadable or
        # truncated image data as OSError subclasses
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}") from e

    try:
        # Save temporarily for embedder (which expects file path)
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                tmp_path = Path(tmp.name)
                image.save(tmp, format="JPEG")
            photo = await embedder.extract_from_photo(tmp_path)
        finally:
            # delete=False means nothing else removes the file; clean up even
            # when saving or extraction fails, without masking that error
            if tmp_path is not None:
                tmp_path.unlink(missing_ok=True)

        if not photo.has_faces:
            return IdentityCompareResponse(
                identity_id=identity_id,
                similarity=0.0,
                confidence="low",
                face_detected=False,
                message="No face detected in provided image",
            )

        # Use highest confidence face
        best_face = max(photo.faces, key=lambda f: f.confidence)

        # Compute cosine similarity; classification sees the raw score
        raw_similarity = cosine_similarity(identity.face_embedding, best_face.embedding)
        confidence = classify_match_confidence(raw_similarity)

        # Cosine similarity can be negative (or fractionally above 1.0 from
        # rounding), but the response schema only accepts 0.0-1.0
        similarity = min(1.0, max(0.0, float(raw_similarity)))

        return IdentityCompareResponse(
            identity_id=identity_id,
            similarity=similarity,
            confidence=confidence,
            face_detected=True,
            message=f"Face matched with {similarity:.2%} similarity",
        )

    except Exception as e:
        logger.exception(f"Failed to compare face against identity '{identity_id}'")
        raise HTTPException(status_code=500, detail=str(e)) from e

View file

@ -304,3 +304,53 @@ class BatchSearchResponse(BaseModel):
description="Search results for each identity"
)
total_images_searched: int = Field(description="Total images searched")
# ============================================================================
# Identity Embedding Schemas (for IP-Adapter integration)
# ============================================================================
class IdentityEmbeddingResponse(BaseModel):
    """Response body carrying one identity's face embedding and its metadata.

    Documented as a 512-dimensional embedding; consumed by imajin-pipeline
    for identity verification after generation.
    """

    identity_id: str = Field(description="Identity ID")
    name: str = Field(description="Human-readable name")
    # Embedding values as a plain list of floats
    embedding: list[float] = Field(description="512-dimensional face embedding vector (normalized)")
    embedding_norm: float = Field(description="Original L2 norm before normalization (typically ~21)")
    image_count: int = Field(description="Number of images used to build the centroid embedding")
class IdentityCompareRequest(BaseModel):
    """Request body for checking a face image against a stored identity."""

    identity_id: str = Field(description="Identity ID to compare against")
    # The image travels as base64 text
    image_base64: str = Field(description="Base64-encoded face image to compare")
class IdentityCompareResponse(BaseModel):
    """Result of comparing a face image with a stored identity."""

    identity_id: str = Field(description="Identity ID compared against")
    # Declared with inclusive bounds 0.0 and 1.0
    similarity: float = Field(
        description="Cosine similarity score between face and identity (0-1)",
        ge=0.0,
        le=1.0,
    )
    # One of three fixed levels
    confidence: Literal["high", "medium", "low"] = Field(
        description="Match confidence level based on similarity",
    )
    face_detected: bool = Field(description="Whether a face was detected in the provided image")
    # Optional; defaults to None when omitted
    message: str | None = Field(description="Optional status message", default=None)