feat(image-pipeline): Introduce new image generation stage with generate endpoint and background inpainter logic

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Claude Code 2026-03-31 07:09:27 -07:00
parent 636a3b4bd4
commit 5d1745211a
4 changed files with 193 additions and 79 deletions

View file

@ -671,13 +671,27 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
# Determine if IP-Adapter is needed for identity conditioning
needs_ip_adapter = context and context.identity_conditioning is not None
# Determine if InstantID ControlNet is needed for enhanced identity fidelity
# Determine if InstantID ControlNet is needed for enhanced identity fidelity.
# Guard: only load ControlNet if InsightFace is available — if it isn't,
# _prepare_instantid_conditioning() will disable InstantID at generation time
# and the ControlNet pipeline would be called with image=None, crashing diffusers.
needs_instantid = (
context
and context.identity_conditioning is not None
and context.identity_conditioning.enable_instantid
and context.identity_conditioning.instantid_image is not None
)
if needs_instantid:
try:
from insightface.app import FaceAnalysis as _FaceAnalysis # noqa: F401
except ImportError:
needs_instantid = False
logger.warning(
"InsightFace not available — InstantID ControlNet will not be loaded; "
"falling back to IP-Adapter-only identity conditioning"
)
if context and context.identity_conditioning is not None:
context.identity_conditioning.enable_instantid = False
if needs_instantid:
controlnet_models.append("instantid")
needs_controlnet = True # Ensure ControlNet pipeline is loaded

View file

@ -123,6 +123,11 @@ async def cleanup_service() -> None:
await queue.stop()
logger.info("GenerationQueue stopped")
background_inpainter = lifespan.get_state("background_inpainter")
if background_inpainter:
await background_inpainter.shutdown()
logger.info("BackgroundInpainter stopped")
job_storage = lifespan.get_state("job_storage")
if job_storage:
await job_storage.close()

View file

@ -597,53 +597,32 @@ async def repaint_background_async(body: RepaintBackgroundRequest) -> dict[str,
import io
import random
from PIL import Image, ImageFilter, ImageOps
from PIL import Image
try:
# Decode source image and snap to SDXL-safe dimensions
source_data = body.source_image
if source_data.startswith("data:"):
source_data = source_data.split(",", 1)[1]
source_bytes = base64.b64decode(source_data)
source_rgb = Image.open(io.BytesIO(source_bytes)).convert("RGB")
source_rgb = Image.open(io.BytesIO(base64.b64decode(source_data))).convert("RGB")
gen_w = max(512, min(1536, (source_rgb.width // 64) * 64))
gen_h = max(512, min(1536, (source_rgb.height // 64) * 64))
source_rgb = source_rgb.resize((gen_w, gen_h), Image.LANCZOS)
# Stage 1: BiRefNet segmentation on CPU executor
# repaint() acquires model-boss GPU lease + gpu_lock, runs BiRefNet CUDA
# segmentation and SDXL inpainting serially under the same lease
await job_storage.update_status(
job.id, StorageJobStatus.RUNNING, current_stage="segment_subject"
)
def _segment(img: Image.Image) -> Image.Image:
from rembg import new_session, remove as rembg_remove
session = new_session("birefnet-general", providers=["CPUExecutionProvider"])
rgba = rembg_remove(img, session=session) # type: ignore[return-value]
alpha = rgba.split()[3] # white=subject
bg_mask = ImageOps.invert(alpha) # white=background (to replace)
bg_mask = bg_mask.filter(ImageFilter.MaxFilter(11)) # dilate ~5px inward
bg_mask = bg_mask.filter(ImageFilter.GaussianBlur(radius=20)) # feather transition
return bg_mask.convert("L")
loop = asyncio.get_running_loop()
background_mask = await loop.run_in_executor(None, _segment, source_rgb)
# Stage 2: SDXL inpainting — regenerates background in full image context
await job_storage.update_status(
job.id, StorageJobStatus.RUNNING, current_stage="inpaint_background"
job.id, StorageJobStatus.RUNNING, current_stage="repaint"
)
bg_prompt, bg_negative = _enforce_rating(
body.maturity_rating, body.background_prompt, body.negative_prompt
)
# Prevent people from appearing in the generated background
bg_negative = (bg_negative or "") + ", person, man, woman, people, human, figure, crowd"
seed = body.seed if body.seed is not None else random.randint(0, 2**31 - 1)
result_image = await background_inpainter.repaint(
source_image=source_rgb,
background_mask=background_mask,
prompt=bg_prompt,
negative_prompt=bg_negative,
steps=body.steps,
@ -651,7 +630,6 @@ async def repaint_background_async(body: RepaintBackgroundRequest) -> dict[str,
seed=seed,
)
# Stage 3: Encode result
await job_storage.update_status(
job.id, StorageJobStatus.RUNNING, current_stage="encode"
)
@ -669,13 +647,13 @@ async def repaint_background_async(body: RepaintBackgroundRequest) -> dict[str,
"height": gen_h,
"output_base64": output_b64,
"seed": seed,
"stages_completed": 3,
"stages_completed": 2,
"total_duration_ms": 0,
"error": None,
}
await job_storage.set_result(job.id, result)
await job_storage.update_status(
job.id, StorageJobStatus.COMPLETED, stages_completed=3
job.id, StorageJobStatus.COMPLETED, stages_completed=2
)
except Exception as e:

View file

@ -1,12 +1,17 @@
"""SDXL inpainting-based background replacement.
"""SDXL inpainting-based background replacement with model-boss GPU coordination.
Uses AutoPipelineForInpainting with `diffusers/stable-diffusion-xl-1.0-inpainting-0.1`
to replace image backgrounds while preserving the subject seamlessly. The model
processes the full image context so lighting, shadows, and edge blending are
handled naturally no hard compositing artefacts.
Full GPU pipeline:
1. BiRefNet segmentation (CUDA ONNX) SOTA subject extraction + mask feathering
2. SDXL inpainting (AutoPipelineForInpainting) regenerates background in full
image context so lighting, shadows, and edge blending are handled by the model
GPU access is serialised via a shared asyncio.Lock that is also held by the
GenerationQueue worker, preventing VRAM conflicts between concurrent requests.
GPU access is coordinated via model-boss: `acquire_lease()` puts the request in
the priority queue and returns only when a GPU slot is available. This integrates
with all other GPU consumers on the system (inference, training, identity) rather
than grabbing VRAM unilaterally.
Both BiRefNet and SDXL inpainting run under the same model-boss lease, serialized
via the shared asyncio.Lock so they never overlap with GenerationQueue GPU work.
"""
from __future__ import annotations
@ -19,53 +24,71 @@ from PIL import Image
logger = logging.getLogger(__name__)
MODEL_ID = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
INPAINTING_MODEL_ID = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
INPAINTING_VRAM_MB = 6000
class BackgroundInpainter:
"""Lazy-loaded SDXL inpainting pipeline for background replacement.
"""BiRefNet segmentation + SDXL inpainting, fully coordinated via model-boss.
Thread-safe: all GPU work runs in a thread executor while the asyncio
event loop holds the shared gpu_lock, preventing overlap with the
GenerationQueue worker.
On the first repaint call:
- Acquires a GPU lease from model-boss (blocks in the priority queue until
a GPU slot with sufficient VRAM is available)
- Loads BiRefNet segmentation model (CUDA ONNX) on the leased GPU
- Loads SDXL inpainting pipeline (fp16) on the leased GPU
- Starts a heartbeat task to keep the lease alive
Subsequent calls reuse the loaded models. If the lease is evicted by
model-boss (higher-priority work or idle timeout), models reload on the
next request.
All GPU work runs in a thread executor under the shared gpu_lock, which
also serializes GenerationQueue jobs no VRAM contention possible.
"""
def __init__(self, gpu_lock: asyncio.Lock) -> None:
self._gpu_lock = gpu_lock
self._pipeline: Optional[object] = None
self._device = "cuda:0"
self._seg_session: Optional[object] = None
self._device: str = "cuda:0" # Updated after lease assignment
self._mb_client: Optional[object] = None
self._mb_lease_id: Optional[str] = None
self._heartbeat_task: Optional[asyncio.Task] = None
async def repaint(
self,
source_image: Image.Image,
background_mask: Image.Image,
prompt: str,
negative_prompt: str,
steps: int,
guidance_scale: float,
seed: int,
) -> Image.Image:
"""Replace the background in `source_image` using SDXL inpainting.
"""Segment subject then replace background using SDXL inpainting.
Acquires the shared gpu_lock so this never overlaps with GenerationQueue.
Acquires a model-boss lease on first call to register VRAM usage with
the system coordinator.
Args:
source_image: RGB source photo (subject to preserve).
background_mask: L-mode mask white (255) = replace, black (0) = keep.
source_image: RGB source photo. Subject is preserved; background replaced.
prompt: Background scene description.
negative_prompt: Negative prompt for the inpainting pass.
steps: Number of diffusion steps (20-40 recommended).
guidance_scale: CFG scale (7-8 recommended for photorealism).
steps: Number of diffusion steps.
guidance_scale: CFG scale.
seed: Deterministic seed.
Returns:
RGB PIL image with the background replaced.
RGB PIL image with background seamlessly replaced.
"""
loop = asyncio.get_running_loop()
async with self._gpu_lock:
if self._pipeline is None:
await self._acquire_and_load(loop)
return await loop.run_in_executor(
None,
self._run_sync,
source_image,
background_mask,
prompt,
negative_prompt,
steps,
@ -73,21 +96,140 @@ class BackgroundInpainter:
seed,
)
async def shutdown(self) -> None:
"""Cancel heartbeat and release model-boss lease."""
if self._heartbeat_task is not None:
self._heartbeat_task.cancel()
try:
await self._heartbeat_task
except asyncio.CancelledError:
pass
self._heartbeat_task = None
if self._mb_client is not None and self._mb_lease_id is not None:
try:
await self._mb_client.release_lease(self._mb_lease_id)
logger.info("Inpainting lease %s released", self._mb_lease_id)
except Exception as exc:
logger.warning("Failed to release inpainting lease: %r", exc)
self._mb_lease_id = None
# -------------------------------------------------------------------------
# Private — lease + load
# -------------------------------------------------------------------------
async def _acquire_and_load(self, loop: asyncio.AbstractEventLoop) -> None:
"""Acquire model-boss GPU lease, then load models in thread executor."""
from model_boss.client import InferenceClient
logger.info(
"Acquiring model-boss GPU lease for inpainting (vram=%dMB)", INPAINTING_VRAM_MB
)
self._mb_client = InferenceClient(
client_id="imajin-inpainting",
auto_start_services=False,
)
lease = await self._mb_client.acquire_lease(
model_id="imajin-pipeline:inpainting",
vram_mb=INPAINTING_VRAM_MB,
priority="high",
endpoint="inpainting",
)
self._mb_lease_id = lease["lease_id"]
gpu_index = lease["gpu_index"]
self._device = f"cuda:{gpu_index}"
logger.info(
"Inpainting lease acquired: GPU %d, lease=%s",
gpu_index, self._mb_lease_id,
)
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
await loop.run_in_executor(None, self._load_sync)
async def _heartbeat_loop(self) -> None:
while True:
try:
await asyncio.sleep(10.0)
alive = await self._mb_client.heartbeat(self._mb_lease_id) # type: ignore[union-attr]
if not alive:
logger.warning(
"Inpainting lease %s evicted — models will reload on next request",
self._mb_lease_id,
)
self._pipeline = None
self._seg_session = None
self._mb_lease_id = None
break
except asyncio.CancelledError:
break
except Exception as exc:
logger.warning("Inpainting lease heartbeat failed: %r", exc)
# -------------------------------------------------------------------------
# Private — sync (runs in thread executor, under gpu_lock + lease)
# -------------------------------------------------------------------------
def _load_sync(self) -> None:
"""Load BiRefNet (CUDA ONNX) and SDXL inpainting pipeline."""
import torch
from diffusers import AutoPipelineForInpainting
from rembg import new_session
gpu_id = int(self._device.split(":")[-1])
logger.info("Loading BiRefNet segmentation model on %s", self._device)
self._seg_session = new_session(
"birefnet-general",
providers=[
("CUDAExecutionProvider", {"device_id": gpu_id}),
"CPUExecutionProvider",
],
)
logger.info("BiRefNet loaded on %s", self._device)
logger.info(
"Loading SDXL inpainting model: %s on %s (first use — may download ~6 GB)",
INPAINTING_MODEL_ID, self._device,
)
try:
pipeline = AutoPipelineForInpainting.from_pretrained(
INPAINTING_MODEL_ID,
torch_dtype=torch.float16,
variant="fp16",
)
except Exception:
logger.warning("fp16 variant unavailable, loading full-precision weights")
pipeline = AutoPipelineForInpainting.from_pretrained(
INPAINTING_MODEL_ID,
torch_dtype=torch.float16,
)
pipeline = pipeline.to(self._device)
pipeline.enable_attention_slicing()
pipeline.enable_vae_slicing()
self._pipeline = pipeline
logger.info("SDXL inpainting pipeline ready on %s", self._device)
def _run_sync(
self,
source_image: Image.Image,
background_mask: Image.Image,
prompt: str,
negative_prompt: str,
steps: int,
guidance_scale: float,
seed: int,
) -> Image.Image:
"""BiRefNet segmentation + SDXL inpainting (blocking, runs in thread executor)."""
import torch
from PIL import ImageFilter, ImageOps
from rembg import remove as rembg_remove
if self._pipeline is None:
self._load()
# BiRefNet → feathered background mask
rgba = rembg_remove(source_image, session=self._seg_session)
alpha = rgba.split()[3] # white = subject
bg_mask = ImageOps.invert(alpha) # white = background
bg_mask = bg_mask.filter(ImageFilter.MaxFilter(11)) # dilate ~5px inward
bg_mask = bg_mask.filter(ImageFilter.GaussianBlur(radius=20)) # feather transition
background_mask = bg_mask.convert("L")
# SDXL inpainting — regenerates background in full image context
generator = torch.Generator(device=self._device).manual_seed(seed)
result = self._pipeline( # type: ignore[misc]
prompt=prompt,
@ -100,28 +242,3 @@ class BackgroundInpainter:
generator=generator,
)
return result.images[0]
def _load(self) -> None:
import torch
from diffusers import AutoPipelineForInpainting
logger.info("Loading SDXL inpainting model: %s (first use — may download ~6 GB)", MODEL_ID)
try:
pipeline = AutoPipelineForInpainting.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16,
variant="fp16",
)
except Exception:
# Some HF mirrors don't have fp16 variant files — fall back to full precision
logger.warning("fp16 variant unavailable, loading full-precision weights")
pipeline = AutoPipelineForInpainting.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16,
)
pipeline = pipeline.to(self._device)
pipeline.enable_attention_slicing()
pipeline.enable_vae_slicing()
self._pipeline = pipeline
logger.info("SDXL inpainting model ready on %s", self._device)