diff --git a/orchestrators/imajin-pipeline/src/image_pipeline/utils/controlnet_manager.py b/orchestrators/imajin-pipeline/src/image_pipeline/utils/controlnet_manager.py
index 701fc77a..dad7dce9 100644
--- a/orchestrators/imajin-pipeline/src/image_pipeline/utils/controlnet_manager.py
+++ b/orchestrators/imajin-pipeline/src/image_pipeline/utils/controlnet_manager.py
@@ -34,6 +34,13 @@ class ControlNetManager:
         "canny_sdxl": "diffusers/controlnet-canny-sdxl-1.0",  # Future
         # SD 3.5 ControlNets (future)
         "openpose_sd35": "InstantX/SD3-ControlNet-Pose",
+        # InstantID face keypoint ControlNet for identity preservation
+        "instantid_sdxl": "InstantX/InstantID",  # ~1.5GB face keypoint ControlNet
+    }
+
+    # Maps a model key to the repo subfolder passed to from_pretrained (only InstantID needs one)
+    CONTROLNET_SUBFOLDERS = {
+        "instantid_sdxl": "ControlNetModel",
     }
 
     # VRAM estimates per model (MB)
@@ -43,6 +50,7 @@ class ControlNetManager:
         "depth_sdxl": 1800,
         "canny_sdxl": 1800,
         "openpose_sd35": 1500,
+        "instantid_sdxl": 1500,  # ~1.5GB face keypoint ControlNet
     }
 
     @classmethod
@@ -92,6 +100,7 @@ class ControlNetManager:
             )
 
         model_id = cls.CONTROLNET_MODELS[model_key]
+        subfolder = cls.CONTROLNET_SUBFOLDERS.get(model_key)  # None when the model has no subfolder entry
 
         # Set dtype
         if dtype is None:
@@ -101,25 +110,33 @@ class ControlNetManager:
         if "cuda" in device:
             cls._check_vram_availability(model_key, device)
 
+        subfolder_info = f", subfolder={subfolder}" if subfolder else ""
         logger.info(
-            f"Loading ControlNet: {model_key} from {model_id} (device={device}, dtype={dtype})"
+            f"Loading ControlNet: {model_key} from {model_id}{subfolder_info} (device={device}, dtype={dtype})"
         )
 
         try:
+            # Build the kwargs shared by both load attempts (safetensors and .bin)
+            load_kwargs = {
+                "torch_dtype": dtype,
+            }
+            if subfolder:
+                load_kwargs["subfolder"] = subfolder
+
             # Try loading with safetensors first (faster), fallback to .bin if not available
             try:
                 controlnet = ControlNetModel.from_pretrained(
                     model_id,
-                    torch_dtype=dtype,
                     use_safetensors=True,
+                    **load_kwargs,
                 )
             except (OSError, ValueError) as e:
                 # Safetensors not available, try .bin format
                 logger.info(f"Safetensors not available for {model_id}, falling back to .bin format")
                 controlnet = ControlNetModel.from_pretrained(
                     model_id,
-                    torch_dtype=dtype,
                     use_safetensors=False,
+                    **load_kwargs,
                 )
 
             # Move to target device
@@ -239,3 +256,56 @@ class ControlNetManager:
         size_mb = (param_size + buffer_size) / (1024 * 1024)
 
         return size_mb
+
+    @classmethod
+    def is_instantid(cls, controlnet_type: str) -> bool:
+        """Check if the ControlNet type is InstantID.
+
+        Args:
+            controlnet_type: ControlNet type name to test (compared against "instantid")
+
+        Returns:
+            True if controlnet_type equals "instantid", False otherwise
+        """
+        return controlnet_type == "instantid"
+
+    @classmethod
+    def load_instantid_controlnet(
+        cls,
+        device: str = "cuda:0",
+        dtype: Optional[torch.dtype] = None,
+    ) -> ControlNetModel:
+        """Load InstantID face keypoint ControlNet.
+
+        Convenience method for loading the InstantID ControlNet which uses
+        face keypoints (landmarks) for identity-preserving generation.
+
+        Args:
+            device: Target device for model placement
+            dtype: Optional dtype for model weights
+
+        Returns:
+            Loaded InstantID ControlNetModel
+        """
+        return cls.load_controlnet(
+            controlnet_type="instantid",
+            model_family="sdxl",
+            device=device,
+            dtype=dtype,
+        )
+
+    @classmethod
+    def get_instantid_guidance_params(cls) -> Dict[str, float]:
+        """Get recommended guidance parameters for InstantID ControlNet.
+
+        InstantID works best with early guidance (0.0-0.5) to establish
+        facial structure, then letting the IP-Adapter handle the rest.
+
+        Returns:
+            Dictionary with control_guidance_start, control_guidance_end and controlnet_conditioning_scale
+        """
+        return {
+            "control_guidance_start": 0.0,
+            "control_guidance_end": 0.5,
+            "controlnet_conditioning_scale": 0.8,
+        }
diff --git a/services/imajin-identity/service/src/api/routes/identities.py b/services/imajin-identity/service/src/api/routes/identities.py
index 32568485..bcdad59b 100644
--- a/services/imajin-identity/service/src/api/routes/identities.py
+++ b/services/imajin-identity/service/src/api/routes/identities.py
@@ -17,6 +17,9 @@ from models.schemas import (
     UpdateIdentityRequest,
     IdentityResponse,
     IdentityListResponse,
+    IdentityEmbeddingResponse,
+    IdentityCompareRequest,
+    IdentityCompareResponse,
 )
 from storage.identity_store import IdentityStore, build_identity_from_folder
 
@@ -267,3 +270,124 @@ async def update_identity(
     except Exception as e:
         logger.exception("Failed to update identity")
         raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.get("/{identity_id}/embedding", response_model=IdentityEmbeddingResponse)
+async def get_identity_embedding(identity_id: str, request: Request) -> IdentityEmbeddingResponse:
+    """Get the raw 512-dim face embedding for an identity.
+
+    This endpoint is used by imajin-pipeline for:
+    - Identity verification after generation (comparing generated face to source)
+    - Direct embedding comparison without re-extracting
+
+    Args:
+        identity_id: Identity ID (lowercase, underscores)
+
+    Returns:
+        Identity embedding with metadata
+    """
+    store = get_store(request)
+
+    identity = await store.get(identity_id)
+    if identity is None:
+        raise HTTPException(status_code=404, detail=f"Identity '{identity_id}' not found")
+
+    import numpy as np
+
+    # Get the embedding and its norm
+    embedding = identity.face_embedding
+    embedding_norm = float(np.linalg.norm(embedding))  # NOTE(review): norm of the vector as stored and returned; confirm it matches the schema's "before normalization" wording
+
+    return IdentityEmbeddingResponse(
+        identity_id=identity_id,
+        name=identity.name,
+        embedding=embedding.tolist(),
+        embedding_norm=embedding_norm,
+        image_count=identity.image_count,
+    )
+
+
+@router.post("/{identity_id}/compare", response_model=IdentityCompareResponse)
+async def compare_identity(
+    identity_id: str,
+    request_body: IdentityCompareRequest,
+    request: Request,
+) -> IdentityCompareResponse:
+    """Compare a face image against an identity.
+
+    Extracts face embedding from the provided image and computes
+    cosine similarity against the identity's centroid embedding.
+
+    Used by imajin-pipeline for post-generation identity verification.
+
+    Args:
+        identity_id: Identity ID to compare against
+        request_body: Request with base64-encoded face image
+
+    Returns:
+        Comparison result with similarity score
+    """
+    import base64
+    import io
+    import numpy as np
+    from PIL import Image
+    from storage.identity_store import cosine_similarity, classify_match_confidence
+
+    embedder = get_embedder(request)
+    store = get_store(request)
+
+    # Get identity (the 404 is raised outside the try below, so it is not rewrapped as a 500)
+    identity = await store.get(identity_id)
+    if identity is None:
+        raise HTTPException(status_code=404, detail=f"Identity '{identity_id}' not found")
+
+    try:
+        # Decode base64 image, dropping a "data:...," URI prefix when present
+        image_data = request_body.image_base64
+        if image_data.startswith("data:"):
+            image_data = image_data.split(",", 1)[1]
+
+        image_bytes = base64.b64decode(image_data)
+        image = Image.open(io.BytesIO(image_bytes))
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+
+        # Extract face from image
+        # Save temporarily for embedder (which expects file path)
+        import tempfile
+        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:  # delete=False: removed explicitly via unlink() below
+            image.save(tmp, format="JPEG")  # NOTE(review): if save raises, the temp file is not removed
+            tmp_path = Path(tmp.name)
+
+        try:
+            photo = await embedder.extract_from_photo(tmp_path)
+        finally:
+            tmp_path.unlink()  # Clean up temp file
+
+        if not photo.has_faces:
+            return IdentityCompareResponse(
+                identity_id=identity_id,
+                similarity=0.0,
+                confidence="low",
+                face_detected=False,
+                message="No face detected in provided image",
+            )
+
+        # Use highest confidence face
+        best_face = max(photo.faces, key=lambda f: f.confidence)
+
+        # Compute cosine similarity
+        similarity = cosine_similarity(identity.face_embedding, best_face.embedding)
+        confidence = classify_match_confidence(similarity)
+
+        return IdentityCompareResponse(
+            identity_id=identity_id,
+            similarity=similarity,
+            confidence=confidence,
+            face_detected=True,
+            message=f"Face matched with {similarity:.2%} similarity",
+        )
+
+    except Exception as e:  # NOTE(review): malformed base64 / undecodable image input also lands here and is reported as HTTP 500
+        logger.exception(f"Failed to compare face against identity '{identity_id}'")
+        raise HTTPException(status_code=500, detail=str(e))
diff --git a/services/imajin-identity/service/src/models/schemas.py b/services/imajin-identity/service/src/models/schemas.py
index 6ca56cc6..2b17bd14 100644
--- a/services/imajin-identity/service/src/models/schemas.py
+++ b/services/imajin-identity/service/src/models/schemas.py
@@ -304,3 +304,53 @@ class BatchSearchResponse(BaseModel):
         description="Search results for each identity"
     )
     total_images_searched: int = Field(description="Total images searched")
+
+
+# ============================================================================
+# Identity Embedding Schemas (for IP-Adapter integration)
+# ============================================================================
+
+
+class IdentityEmbeddingResponse(BaseModel):
+    """Response containing the raw 512-dim face embedding for an identity.
+
+    Used by imajin-pipeline for identity verification after generation.
+    """
+
+    identity_id: str = Field(description="Identity ID")
+    name: str = Field(description="Human-readable name")
+    embedding: list[float] = Field(
+        description="512-dimensional face embedding vector (normalized)"
+    )
+    embedding_norm: float = Field(
+        description="Original L2 norm before normalization (typically ~21)"
+    )
+    image_count: int = Field(
+        description="Number of images used to build the centroid embedding"
+    )
+
+
+class IdentityCompareRequest(BaseModel):
+    """Request for comparing a face image against an identity."""
+
+    identity_id: str = Field(description="Identity ID to compare against")  # NOTE(review): required here, but compare_identity uses the path parameter and never reads this field
+    image_base64: str = Field(
+        description="Base64-encoded face image to compare"
+    )
+
+
+class IdentityCompareResponse(BaseModel):
+    """Response from identity comparison."""
+
+    identity_id: str = Field(description="Identity ID compared against")
+    similarity: float = Field(
+        ge=0.0, le=1.0,  # NOTE(review): rejects negative scores; confirm cosine_similarity never returns a value below 0
+        description="Cosine similarity score between face and identity (0-1)"
+    )
+    confidence: Literal["high", "medium", "low"] = Field(
+        description="Match confidence level based on similarity"
+    )
+    face_detected: bool = Field(
+        description="Whether a face was detected in the provided image"
+    )
+    message: str | None = Field(default=None, description="Optional status message")