feat(api-routes): ✨ Add identity management endpoints and CLI commands for CRUD operations
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
fceda0e6a9
commit
737cbaebef
2 changed files with 361 additions and 5 deletions
|
|
@ -0,0 +1,143 @@
|
|||
#!/usr/bin/env python3
|
||||
"""One-time migration: move global identity Redis keys to a namespace.
|
||||
|
||||
Scans keys matching `imajin:identities:*:embedding` and selects those in the
old global format `imajin:identities:{name}:embedding` (exactly four
colon-separated segments). Each one is re-written, together with its metadata
key and index entry, under
`imajin:identities:{LILITH_USER_ID}:{name}:embedding`.
|
||||
|
||||
Usage:
|
||||
    LILITH_USER_ID=<uuid> REDIS_URL=redis://localhost:6387 python migrate_global_to_namespace.py [--dry-run]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# Canonical 8-4-4-4-12 hexadecimal UUID text form, matched case-insensitively.
UUID_RE = re.compile(
    r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$',
    re.IGNORECASE,
)

# Common prefix of every identity key handled by this script.
IDENTITIES_PREFIX = "imajin:identities"


def is_uuid(s: str) -> bool:
    """Return True if *s* is exactly a UUID in 8-4-4-4-12 hex form.

    ``fullmatch`` is used instead of ``match`` because ``$`` also matches just
    before a trailing newline; with ``match`` a value such as ``"<uuid>\\n"``
    (easy to get from an environment variable) would be accepted and then
    embedded, newline included, into Redis key names.
    """
    return UUID_RE.fullmatch(s) is not None
|
||||
|
||||
|
||||
async def migrate(redis_url: str, lilith_user_id: str, dry_run: bool = False) -> None:
    """Move old global identity keys into the ``lilith_user_id`` namespace.

    Scans for ``{IDENTITIES_PREFIX}:*:embedding`` keys. Keys with exactly four
    colon-separated segments are treated as the old global format and, for
    each of them, the embedding value, the sibling ``:metadata`` value and the
    index membership are moved under
    ``{IDENTITIES_PREFIX}:{lilith_user_id}:{name}:...``. Five-segment keys are
    left alone; anything else is logged and skipped.

    Args:
        redis_url: URL passed to ``redis.from_url``.
        lilith_user_id: Namespace segment inserted into the new key names.
        dry_run: When true, only log what would be done; Redis is not modified.
    """
    client = redis.from_url(redis_url)
    # The index keys do not depend on the identity being processed.
    old_index_key = f"{IDENTITIES_PREFIX}:index"
    new_index_key = f"{IDENTITIES_PREFIX}:{lilith_user_id}:index"
    try:
        # Scan for all embedding keys
        pattern = f"{IDENTITIES_PREFIX}:*:embedding"
        # SCAN may report the same key more than once, so collect into a dict
        # (insertion-ordered) to de-duplicate while keeping discovery order.
        seen: dict[str, None] = {}

        async for key in client.scan_iter(pattern):
            key_str = key.decode() if isinstance(key, bytes) else key
            parts = key_str.split(":")
            # Old format: imajin:identities:{name}:embedding     -> 4 parts
            # New format: imajin:identities:{ns}:{id}:embedding -> 5 parts
            if len(parts) == 4:
                seen[key_str] = None
            elif len(parts) == 5:
                # Already in the new layout. Warn only when neither middle
                # segment looks like a UUID: that could be an old-format name
                # containing a colon, which is skipped for safety.
                if not is_uuid(parts[2]) and not is_uuid(parts[3]):
                    logger.warning(f"Unrecognized key format, skipping: {key_str}")
            else:
                logger.warning(f"Unexpected key format, skipping: {key_str}")

        old_keys = list(seen)

        if not old_keys:
            logger.info("No old-format identity keys found. Nothing to migrate.")
            return

        logger.info(f"Found {len(old_keys)} old-format identity key(s) to migrate")

        migrated = 0
        for embedding_key in old_keys:
            name = embedding_key.split(":")[2]  # imajin:identities:{name}:embedding

            new_embedding_key = f"{IDENTITIES_PREFIX}:{lilith_user_id}:{name}:embedding"
            new_metadata_key = f"{IDENTITIES_PREFIX}:{lilith_user_id}:{name}:metadata"
            old_metadata_key = f"{IDENTITIES_PREFIX}:{name}:metadata"

            logger.info(f" Migrating '{name}' -> namespace '{lilith_user_id}'")

            if dry_run:
                logger.info(f" [DRY RUN] Would rename {embedding_key} -> {new_embedding_key}")
                logger.info(f" [DRY RUN] Would rename {old_metadata_key} -> {new_metadata_key}")
                logger.info(f" [DRY RUN] Would add '{name}' to {new_index_key}")
                logger.info(f" [DRY RUN] Would remove '{name}' from {old_index_key}")
                continue

            embedding_bytes = await client.get(embedding_key)
            if embedding_bytes is None:
                # The key disappeared between SCAN and GET. Do not register an
                # index entry that would point at no embedding.
                logger.warning(f" Embedding key {embedding_key} no longer exists, skipping '{name}'")
                continue
            metadata_bytes = await client.get(old_metadata_key)

            # Apply all writes for one identity in a single MULTI/EXEC so a
            # failure cannot leave it half-moved (e.g. present in both indexes).
            async with client.pipeline(transaction=True) as pipe:
                pipe.set(new_embedding_key, embedding_bytes)
                if metadata_bytes is not None:
                    pipe.set(new_metadata_key, metadata_bytes)
                pipe.sadd(new_index_key, name)
                pipe.srem(old_index_key, name)
                pipe.delete(embedding_key, old_metadata_key)
                await pipe.execute()

            migrated += 1
            logger.info(f" Migrated '{name}' successfully")

        if dry_run:
            logger.info(f"Dry run complete. {len(old_keys)} identit(ies) would be migrated.")
            return

        # Clean up global index if empty
        if not await client.scard(old_index_key):
            await client.delete(old_index_key)
            logger.info("Deleted empty global index key")

        logger.info(f"Migration complete. Migrated {migrated} identit(ies).")
    finally:
        await client.aclose()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
lilith_user_id = os.environ.get("LILITH_USER_ID")
|
||||
redis_url = os.environ.get("REDIS_URL", "redis://localhost:6387")
|
||||
dry_run = "--dry-run" in sys.argv
|
||||
|
||||
if not lilith_user_id:
|
||||
logger.error("LILITH_USER_ID environment variable is required")
|
||||
sys.exit(1)
|
||||
|
||||
if not is_uuid(lilith_user_id):
|
||||
logger.error(f"LILITH_USER_ID must be a valid UUID, got: {lilith_user_id}")
|
||||
sys.exit(1)
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN mode — no changes will be made")
|
||||
|
||||
logger.info(f"Migrating to namespace: {lilith_user_id}")
|
||||
logger.info(f"Redis URL: {redis_url}")
|
||||
|
||||
await migrate(redis_url, lilith_user_id, dry_run=dry_run)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -5,23 +5,39 @@ Provides CRUD operations for identity profiles.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
||||
from config import get_settings
|
||||
from detection import FaceEmbedder
|
||||
from models.schemas import (
|
||||
BuildIdentityFromUrlsRequest,
|
||||
CreateIdentityRequest,
|
||||
UpdateIdentityRequest,
|
||||
IdentityResponse,
|
||||
IdentityListResponse,
|
||||
IdentityEmbeddingResponse,
|
||||
IdentityBuildResponse,
|
||||
IdentityCompareRequest,
|
||||
IdentityCompareResponse,
|
||||
IdentityEmbeddingResponse,
|
||||
IdentityListResponse,
|
||||
IdentityResponse,
|
||||
NamespacedSearchRequest,
|
||||
NamespacedSearchResponse,
|
||||
UpdateIdentityRequest,
|
||||
)
|
||||
from storage.identity_store import (
|
||||
Identity,
|
||||
IdentityStore,
|
||||
build_identity_from_folder,
|
||||
classify_match_confidence,
|
||||
cosine_similarity,
|
||||
get_store_for_namespace,
|
||||
)
|
||||
from storage.identity_store import IdentityStore, build_identity_from_folder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/identities", tags=["identities"])
|
||||
|
|
@ -392,3 +408,200 @@ async def compare_identity(
|
|||
except Exception as e:
|
||||
logger.exception(f"Failed to compare face against identity '{identity_id}'")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# L2 norm that stored centroids are rescaled to.
# NOTE(review): 21.0 is carried over unchanged from the previous inline code;
# its rationale is not visible in this file — confirm before changing.
_CENTROID_TARGET_NORM = 21.0


def _rescale_centroid(vector: np.ndarray) -> np.ndarray:
    """Return *vector* as float32, rescaled to ``_CENTROID_TARGET_NORM``.

    A zero vector is returned as-is (only cast), since it has no direction.
    """
    centroid = vector.astype(np.float32)
    norm = np.linalg.norm(centroid)
    if norm > 0:
        centroid = (centroid / norm * _CENTROID_TARGET_NORM).astype(np.float32)
    return centroid


@router.post("/from-urls", response_model=IdentityBuildResponse)
async def build_identity_from_urls(
    request_body: BuildIdentityFromUrlsRequest,
    request: Request,
) -> IdentityBuildResponse:
    """Build or update an identity centroid from presigned MinIO URLs.

    Downloads the images concurrently, keeps the highest-confidence face of
    each image when it reaches ``request_body.min_confidence``, averages the
    embeddings into a centroid and saves it under ``namespace:identity_id``.
    With ``upsert`` set and an existing identity present, the stored centroid
    is combined with the new mean, weighted by image counts.

    Raises:
        HTTPException: 400 when no image yields a usable face; 500 on any
            other failure while building or saving the identity.
    """
    embedder = get_embedder(request)
    base_store = get_store(request)
    store = get_store_for_namespace(base_store, request_body.namespace)
    # NOTE(review): this assigns a private attribute of the store; presumably
    # it makes the namespaced store reuse the base store's connection — confirm
    # whether get_store_for_namespace could do this itself.
    store._redis = base_store._redis

    async def download_and_embed(client: httpx.AsyncClient, url: str) -> list:
        """Download one image and return ``[embedding]`` or ``[]``.

        Any failure for a single URL is logged and treated as "no face" so
        one bad image does not abort the whole build.
        """
        try:
            resp = await client.get(url)
            resp.raise_for_status()
            image_bytes = resp.content

            # Choose a file suffix from the content type; default to JPEG.
            suffix = ".jpg"
            ct = resp.headers.get("content-type", "")
            if "png" in ct:
                suffix = ".png"
            elif "webp" in ct:
                suffix = ".webp"

            tmp_path = None
            try:
                # The embedder is given a path, so spill the bytes to a
                # temporary file that is removed in the finally below.
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                    tmp.write(image_bytes)
                    tmp_path = Path(tmp.name)

                photo = await embedder.extract_from_photo(tmp_path)
                if photo.has_faces:
                    best = max(photo.faces, key=lambda f: f.confidence)
                    if best.confidence >= request_body.min_confidence:
                        return [best.embedding]
                return []
            finally:
                # Best-effort cleanup; a leftover temp file is not fatal.
                if tmp_path and Path(tmp_path).exists():
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass
        except Exception as e:
            logger.warning(f"Failed to process URL {url}: {e}")
            return []

    try:
        # One shared client for all downloads instead of one client per URL,
        # so connections can be pooled and reused.
        async with httpx.AsyncClient(timeout=30.0) as client:
            embedding_lists = await asyncio.gather(
                *(download_and_embed(client, url) for url in request_body.image_urls)
            )
        embeddings = [e for sublist in embedding_lists for e in sublist]

        if not embeddings:
            raise HTTPException(
                status_code=400,
                detail="No faces with sufficient confidence found in provided images",
            )

        new_mean = np.mean(np.array(embeddings, dtype=np.float32), axis=0)

        existing = None
        if request_body.upsert and await store.exists(request_body.identity_id):
            # May still be None: the id can exist in the index without being
            # fetchable, in which case it is treated as a new identity.
            existing = await store.get(request_body.identity_id)

        if existing is not None:
            # Weighted average: existing centroid weighted by its image_count
            existing_weight = existing.image_count
            new_weight = len(embeddings)
            total_weight = existing_weight + new_weight
            centroid = _rescale_centroid(
                (existing.face_embedding * existing_weight + new_mean * new_weight) / total_weight
            )
            total_count = existing.image_count + len(embeddings)
        else:
            centroid = _rescale_centroid(new_mean)
            total_count = len(embeddings)

        identity = Identity(
            name=request_body.display_name,
            face_embedding=centroid,
            image_count=total_count,
            metadata={
                "namespace": request_body.namespace,
                "identity_id": request_body.identity_id,
            },
        )

        await store.save(identity, identity_id=request_body.identity_id)

        return IdentityBuildResponse(
            success=True,
            identity_id=request_body.identity_id,
            namespace=request_body.namespace,
            image_count=total_count,
            message=f"Identity built from {total_count} face(s)",
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to build identity from URLs")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/search-in-namespace", response_model=NamespacedSearchResponse)
async def search_in_namespace(
    request_body: NamespacedSearchRequest,
    request: Request,
) -> NamespacedSearchResponse:
    """Compare a single image URL against an identity in a namespace.

    Used by media-gallery identity-matching pipeline. The image is downloaded,
    the highest-confidence face in it is compared with the stored identity
    embedding via ``cosine_similarity`` and the score is classified with
    ``classify_match_confidence``. An image without faces yields a response
    with ``face_detected=False`` and similarity ``0.0`` rather than an error.

    Raises:
        HTTPException: 404 when the identity is not in the namespace; 400 when
            the image URL is malformed; 502 when the download fails; 500 when
            face extraction fails.
    """
    embedder = get_embedder(request)
    base_store = get_store(request)
    store = get_store_for_namespace(base_store, request_body.namespace)
    # NOTE(review): this assigns a private attribute of the store; presumably
    # it makes the namespaced store reuse the base store's connection — confirm.
    store._redis = base_store._redis

    identity = await store.get(request_body.identity_id)
    if identity is None:
        raise HTTPException(
            status_code=404,
            detail=f"Identity '{request_body.identity_id}' not found in namespace '{request_body.namespace}'",
        )

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.get(request_body.image_url)
            resp.raise_for_status()
            image_bytes = resp.content
    except httpx.InvalidURL as e:
        # InvalidURL is not an httpx.HTTPError subclass, so without this
        # clause a malformed URL would surface as an unhandled server error.
        raise HTTPException(status_code=400, detail=f"Invalid image URL: {e}") from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=502, detail=f"Failed to download image: {e}") from e

    # Choose a file suffix from the content type; default to JPEG.
    suffix = ".jpg"
    ct = resp.headers.get("content-type", "")
    if "png" in ct:
        suffix = ".png"
    elif "webp" in ct:
        suffix = ".webp"

    tmp_path = None
    try:
        # The embedder is given a path, so spill the bytes to a temporary file
        # that is removed in the finally below.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(image_bytes)
            tmp_path = Path(tmp.name)

        photo = await embedder.extract_from_photo(tmp_path)
    except Exception as e:
        # Same log-and-500 handling as the other identity endpoints.
        logger.exception(
            f"Failed to extract faces for identity '{request_body.identity_id}'"
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Best-effort cleanup; a leftover temp file is not fatal.
        if tmp_path and Path(tmp_path).exists():
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    if not photo.has_faces:
        return NamespacedSearchResponse(
            identity_id=request_body.identity_id,
            namespace=request_body.namespace,
            similarity=0.0,
            confidence="low",
            face_detected=False,
            message="No face detected in provided image",
        )

    best_face = max(photo.faces, key=lambda f: f.confidence)
    similarity = cosine_similarity(identity.face_embedding, best_face.embedding)
    confidence = classify_match_confidence(similarity)

    return NamespacedSearchResponse(
        identity_id=request_body.identity_id,
        namespace=request_body.namespace,
        similarity=similarity,
        confidence=confidence,
        face_detected=True,
        message=f"Similarity: {similarity:.3f}",
    )
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue