chore(identity): 🔧 Add identity association to pipeline jobs with new API routes and ControlNet validation
This commit is contained in:
parent
db1ca7c889
commit
5e63461ba5
3 changed files with 247 additions and 3 deletions
|
|
@ -34,6 +34,13 @@ class ControlNetManager:
|
|||
"canny_sdxl": "diffusers/controlnet-canny-sdxl-1.0", # Future
|
||||
# SD 3.5 ControlNets (future)
|
||||
"openpose_sd35": "InstantX/SD3-ControlNet-Pose",
|
||||
# InstantID face keypoint ControlNet for identity preservation
|
||||
"instantid_sdxl": "InstantX/InstantID", # ~1.5GB face keypoint ControlNet
|
||||
}
|
||||
|
||||
# InstantID uses a subfolder for ControlNet
|
||||
CONTROLNET_SUBFOLDERS = {
|
||||
"instantid_sdxl": "ControlNetModel",
|
||||
}
|
||||
|
||||
# VRAM estimates per model (MB)
|
||||
|
|
@ -43,6 +50,7 @@ class ControlNetManager:
|
|||
"depth_sdxl": 1800,
|
||||
"canny_sdxl": 1800,
|
||||
"openpose_sd35": 1500,
|
||||
"instantid_sdxl": 1500, # ~1.5GB face keypoint ControlNet
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -92,6 +100,7 @@ class ControlNetManager:
|
|||
)
|
||||
|
||||
model_id = cls.CONTROLNET_MODELS[model_key]
|
||||
subfolder = cls.CONTROLNET_SUBFOLDERS.get(model_key)
|
||||
|
||||
# Set dtype
|
||||
if dtype is None:
|
||||
|
|
@ -101,25 +110,33 @@ class ControlNetManager:
|
|||
if "cuda" in device:
|
||||
cls._check_vram_availability(model_key, device)
|
||||
|
||||
subfolder_info = f", subfolder={subfolder}" if subfolder else ""
|
||||
logger.info(
|
||||
f"Loading ControlNet: {model_key} from {model_id} (device={device}, dtype={dtype})"
|
||||
f"Loading ControlNet: {model_key} from {model_id}{subfolder_info} (device={device}, dtype={dtype})"
|
||||
)
|
||||
|
||||
try:
|
||||
# Build load kwargs
|
||||
load_kwargs = {
|
||||
"torch_dtype": dtype,
|
||||
}
|
||||
if subfolder:
|
||||
load_kwargs["subfolder"] = subfolder
|
||||
|
||||
# Try loading with safetensors first (faster), fallback to .bin if not available
|
||||
try:
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=dtype,
|
||||
use_safetensors=True,
|
||||
**load_kwargs,
|
||||
)
|
||||
except (OSError, ValueError) as e:
|
||||
# Safetensors not available, try .bin format
|
||||
logger.info(f"Safetensors not available for {model_id}, falling back to .bin format")
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=dtype,
|
||||
use_safetensors=False,
|
||||
**load_kwargs,
|
||||
)
|
||||
|
||||
# Move to target device
|
||||
|
|
@ -239,3 +256,56 @@ class ControlNetManager:
|
|||
|
||||
size_mb = (param_size + buffer_size) / (1024 * 1024)
|
||||
return size_mb
|
||||
|
||||
@classmethod
def is_instantid(cls, controlnet_type: str) -> bool:
    """Tell whether a ControlNet type name refers to InstantID.

    Args:
        controlnet_type: ControlNet type name to test.

    Returns:
        True when ``controlnet_type`` equals ``"instantid"``, False otherwise.
    """
    # Exact, case-sensitive comparison against the InstantID type name.
    matches_instantid = controlnet_type == "instantid"
    return matches_instantid
|
||||
|
||||
@classmethod
def load_instantid_controlnet(
    cls,
    device: str = "cuda:0",
    dtype: Optional[torch.dtype] = None,
) -> ControlNetModel:
    """Load the InstantID face keypoint ControlNet.

    Thin convenience wrapper that delegates to ``load_controlnet`` with the
    ControlNet type fixed to ``"instantid"`` and the model family fixed to
    ``"sdxl"``.

    Args:
        device: Target device for model placement.
        dtype: Optional dtype for the model weights; passed through unchanged.

    Returns:
        Whatever ``load_controlnet`` returns for the InstantID ControlNet.
    """
    instantid_model = cls.load_controlnet(
        controlnet_type="instantid",
        model_family="sdxl",
        device=device,
        dtype=dtype,
    )
    return instantid_model
|
||||
|
||||
@classmethod
def get_instantid_guidance_params(cls) -> Dict[str, float]:
    """Return the recommended guidance settings for the InstantID ControlNet.

    InstantID works best with early guidance (0.0-0.5) to establish
    facial structure, then letting the IP-Adapter handle the rest.

    Returns:
        A new dictionary with the keys ``control_guidance_start``,
        ``control_guidance_end`` and ``controlnet_conditioning_scale``.
    """
    # A fresh dict is built on every call, so callers may mutate the result.
    return dict(
        control_guidance_start=0.0,
        control_guidance_end=0.5,
        controlnet_conditioning_scale=0.8,
    )
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ from models.schemas import (
|
|||
UpdateIdentityRequest,
|
||||
IdentityResponse,
|
||||
IdentityListResponse,
|
||||
IdentityEmbeddingResponse,
|
||||
IdentityCompareRequest,
|
||||
IdentityCompareResponse,
|
||||
)
|
||||
from storage.identity_store import IdentityStore, build_identity_from_folder
|
||||
|
||||
|
|
@ -267,3 +270,124 @@ async def update_identity(
|
|||
except Exception as e:
|
||||
logger.exception("Failed to update identity")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{identity_id}/embedding", response_model=IdentityEmbeddingResponse)
async def get_identity_embedding(identity_id: str, request: Request) -> IdentityEmbeddingResponse:
    """Return the stored face embedding for an identity.

    Intended for imajin-pipeline, which uses the raw 512-dim vector for
    post-generation identity verification and for direct embedding
    comparison without re-extracting it.

    Args:
        identity_id: Identity ID (lowercase, underscores).
        request: Incoming request, used to look up the identity store.

    Returns:
        The identity's embedding together with its L2 norm, name and
        image count.

    Raises:
        HTTPException: 404 when no identity with ``identity_id`` exists.
    """
    record = await get_store(request).get(identity_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"Identity '{identity_id}' not found")

    import numpy as np

    # The norm is computed on the embedding exactly as stored on the identity.
    vector = record.face_embedding
    vector_norm = float(np.linalg.norm(vector))

    return IdentityEmbeddingResponse(
        identity_id=identity_id,
        name=record.name,
        embedding=vector.tolist(),
        embedding_norm=vector_norm,
        image_count=record.image_count,
    )
|
||||
|
||||
|
||||
@router.post("/{identity_id}/compare", response_model=IdentityCompareResponse)
async def compare_identity(
    identity_id: str,
    request_body: IdentityCompareRequest,
    request: Request,
) -> IdentityCompareResponse:
    """Compare a face image against an identity.

    Extracts a face embedding from the provided image and computes the
    cosine similarity against the identity's stored embedding.

    Used by imajin-pipeline for post-generation identity verification.

    Args:
        identity_id: Identity ID to compare against.
        request_body: Request carrying the base64-encoded face image
            (an optional ``data:...,`` URI prefix is stripped).
        request: Incoming request, used to look up the embedder and store.

    Returns:
        Comparison result with similarity score and confidence. When no
        face is found, ``face_detected`` is False and similarity is 0.0.

    Raises:
        HTTPException: 404 when the identity does not exist, 400 when the
            payload is not a decodable image, 500 on any other failure.
    """
    import base64
    import io
    import tempfile

    from PIL import Image
    from storage.identity_store import cosine_similarity, classify_match_confidence

    embedder = get_embedder(request)
    store = get_store(request)

    identity = await store.get(identity_id)
    if identity is None:
        raise HTTPException(status_code=404, detail=f"Identity '{identity_id}' not found")

    try:
        # Decode the client-supplied image. Failures here are the caller's
        # fault, so they are reported as 400 rather than 500.
        image_data = request_body.image_base64
        if image_data.startswith("data:"):
            # partition() never raises, unlike split(",", 1)[1] on a data
            # URI with no comma; a missing payload fails decoding below.
            _, _, image_data = image_data.partition(",")

        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes))
            if image.mode != "RGB":
                image = image.convert("RGB")
        except (ValueError, OSError) as e:
            # binascii.Error is a ValueError; PIL's UnidentifiedImageError
            # is an OSError.
            raise HTTPException(status_code=400, detail=f"Invalid image data: {e}") from e

        # The embedder expects a file path, so round-trip through a temp file.
        # Reserve the name first, then do all work inside try/finally so the
        # file is removed even when saving the JPEG fails.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            tmp_path = Path(tmp.name)
        try:
            image.save(tmp_path, format="JPEG")
            photo = await embedder.extract_from_photo(tmp_path)
        finally:
            tmp_path.unlink(missing_ok=True)

        if not photo.has_faces:
            return IdentityCompareResponse(
                identity_id=identity_id,
                similarity=0.0,
                confidence="low",
                face_detected=False,
                message="No face detected in provided image",
            )

        # Use highest confidence face
        best_face = max(photo.faces, key=lambda f: f.confidence)

        raw_similarity = float(cosine_similarity(identity.face_embedding, best_face.embedding))
        confidence = classify_match_confidence(raw_similarity)

        # The response model constrains similarity to [0, 1]. Cosine similarity
        # can be negative for dissimilar faces (and may drift marginally above
        # 1.0 from float rounding), so clamp instead of failing validation.
        similarity = min(1.0, max(0.0, raw_similarity))

        return IdentityCompareResponse(
            identity_id=identity_id,
            similarity=similarity,
            confidence=confidence,
            face_detected=True,
            message=f"Face matched with {similarity:.2%} similarity",
        )

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) as-is.
        raise
    except Exception as e:
        logger.exception(f"Failed to compare face against identity '{identity_id}'")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
|
|
|||
|
|
@ -304,3 +304,53 @@ class BatchSearchResponse(BaseModel):
|
|||
description="Search results for each identity"
|
||||
)
|
||||
total_images_searched: int = Field(description="Total images searched")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Identity Embedding Schemas (for IP-Adapter integration)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class IdentityEmbeddingResponse(BaseModel):
    """Response containing the raw 512-dim face embedding for an identity.

    Used by imajin-pipeline for identity verification after generation.
    """

    # Which identity the embedding belongs to.
    identity_id: str = Field(description="Identity ID")
    name: str = Field(description="Human-readable name")

    # The embedding vector, serialized as a plain list of floats.
    embedding: list[float] = Field(description="512-dimensional face embedding vector (normalized)")

    # NOTE(review): the description says "before normalization"; confirm the
    # producer of this response computes the norm on the un-normalized vector.
    embedding_norm: float = Field(description="Original L2 norm before normalization (typically ~21)")

    image_count: int = Field(description="Number of images used to build the centroid embedding")
|
||||
|
||||
|
||||
class IdentityCompareRequest(BaseModel):
    """Request for comparing a face image against an identity."""

    # Target identity for the comparison.
    identity_id: str = Field(description="Identity ID to compare against")

    # Image payload, transported as a base64 string.
    image_base64: str = Field(description="Base64-encoded face image to compare")
|
||||
|
||||
|
||||
class IdentityCompareResponse(BaseModel):
    """Response from identity comparison."""

    identity_id: str = Field(description="Identity ID compared against")

    # NOTE(review): validation rejects values outside [0, 1]; a raw cosine
    # similarity can be negative, so producers presumably need to clamp.
    similarity: float = Field(
        description="Cosine similarity score between face and identity (0-1)",
        ge=0.0,
        le=1.0,
    )

    # Coarse bucket derived from the similarity score.
    confidence: Literal["high", "medium", "low"] = Field(description="Match confidence level based on similarity")

    face_detected: bool = Field(description="Whether a face was detected in the provided image")

    # Free-form, human-readable status; omitted (None) by default.
    message: str | None = Field(description="Optional status message", default=None)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue