feat(image-pipeline): ✨ Introduce new image generation stage with generate endpoint and background inpainter logic
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
636a3b4bd4
commit
5d1745211a
4 changed files with 193 additions and 79 deletions
|
|
@ -671,13 +671,27 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
|
|||
# Determine if IP-Adapter is needed for identity conditioning
|
||||
needs_ip_adapter = context and context.identity_conditioning is not None
|
||||
|
||||
# Determine if InstantID ControlNet is needed for enhanced identity fidelity
|
||||
# Determine if InstantID ControlNet is needed for enhanced identity fidelity.
|
||||
# Guard: only load ControlNet if InsightFace is available — if it isn't,
|
||||
# _prepare_instantid_conditioning() will disable InstantID at generation time
|
||||
# and the ControlNet pipeline would be called with image=None, crashing diffusers.
|
||||
needs_instantid = (
|
||||
context
|
||||
and context.identity_conditioning is not None
|
||||
and context.identity_conditioning.enable_instantid
|
||||
and context.identity_conditioning.instantid_image is not None
|
||||
)
|
||||
if needs_instantid:
|
||||
try:
|
||||
from insightface.app import FaceAnalysis as _FaceAnalysis # noqa: F401
|
||||
except ImportError:
|
||||
needs_instantid = False
|
||||
logger.warning(
|
||||
"InsightFace not available — InstantID ControlNet will not be loaded; "
|
||||
"falling back to IP-Adapter-only identity conditioning"
|
||||
)
|
||||
if context and context.identity_conditioning is not None:
|
||||
context.identity_conditioning.enable_instantid = False
|
||||
if needs_instantid:
|
||||
controlnet_models.append("instantid")
|
||||
needs_controlnet = True # Ensure ControlNet pipeline is loaded
|
||||
|
|
|
|||
|
|
@ -123,6 +123,11 @@ async def cleanup_service() -> None:
|
|||
await queue.stop()
|
||||
logger.info("GenerationQueue stopped")
|
||||
|
||||
background_inpainter = lifespan.get_state("background_inpainter")
|
||||
if background_inpainter:
|
||||
await background_inpainter.shutdown()
|
||||
logger.info("BackgroundInpainter stopped")
|
||||
|
||||
job_storage = lifespan.get_state("job_storage")
|
||||
if job_storage:
|
||||
await job_storage.close()
|
||||
|
|
|
|||
|
|
@ -597,53 +597,32 @@ async def repaint_background_async(body: RepaintBackgroundRequest) -> dict[str,
|
|||
import io
|
||||
import random
|
||||
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
# Decode source image and snap to SDXL-safe dimensions
|
||||
source_data = body.source_image
|
||||
if source_data.startswith("data:"):
|
||||
source_data = source_data.split(",", 1)[1]
|
||||
source_bytes = base64.b64decode(source_data)
|
||||
source_rgb = Image.open(io.BytesIO(source_bytes)).convert("RGB")
|
||||
source_rgb = Image.open(io.BytesIO(base64.b64decode(source_data))).convert("RGB")
|
||||
gen_w = max(512, min(1536, (source_rgb.width // 64) * 64))
|
||||
gen_h = max(512, min(1536, (source_rgb.height // 64) * 64))
|
||||
source_rgb = source_rgb.resize((gen_w, gen_h), Image.LANCZOS)
|
||||
|
||||
# Stage 1: BiRefNet segmentation on CPU executor
|
||||
# repaint() acquires model-boss GPU lease + gpu_lock, runs BiRefNet CUDA
|
||||
# segmentation and SDXL inpainting serially under the same lease
|
||||
await job_storage.update_status(
|
||||
job.id, StorageJobStatus.RUNNING, current_stage="segment_subject"
|
||||
)
|
||||
|
||||
def _segment(img: Image.Image) -> Image.Image:
|
||||
from rembg import new_session, remove as rembg_remove
|
||||
session = new_session("birefnet-general", providers=["CPUExecutionProvider"])
|
||||
rgba = rembg_remove(img, session=session) # type: ignore[return-value]
|
||||
alpha = rgba.split()[3] # white=subject
|
||||
bg_mask = ImageOps.invert(alpha) # white=background (to replace)
|
||||
bg_mask = bg_mask.filter(ImageFilter.MaxFilter(11)) # dilate ~5px inward
|
||||
bg_mask = bg_mask.filter(ImageFilter.GaussianBlur(radius=20)) # feather transition
|
||||
return bg_mask.convert("L")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
background_mask = await loop.run_in_executor(None, _segment, source_rgb)
|
||||
|
||||
# Stage 2: SDXL inpainting — regenerates background in full image context
|
||||
await job_storage.update_status(
|
||||
job.id, StorageJobStatus.RUNNING, current_stage="inpaint_background"
|
||||
job.id, StorageJobStatus.RUNNING, current_stage="repaint"
|
||||
)
|
||||
|
||||
bg_prompt, bg_negative = _enforce_rating(
|
||||
body.maturity_rating, body.background_prompt, body.negative_prompt
|
||||
)
|
||||
# Prevent people from appearing in the generated background
|
||||
bg_negative = (bg_negative or "") + ", person, man, woman, people, human, figure, crowd"
|
||||
|
||||
seed = body.seed if body.seed is not None else random.randint(0, 2**31 - 1)
|
||||
|
||||
result_image = await background_inpainter.repaint(
|
||||
source_image=source_rgb,
|
||||
background_mask=background_mask,
|
||||
prompt=bg_prompt,
|
||||
negative_prompt=bg_negative,
|
||||
steps=body.steps,
|
||||
|
|
@ -651,7 +630,6 @@ async def repaint_background_async(body: RepaintBackgroundRequest) -> dict[str,
|
|||
seed=seed,
|
||||
)
|
||||
|
||||
# Stage 3: Encode result
|
||||
await job_storage.update_status(
|
||||
job.id, StorageJobStatus.RUNNING, current_stage="encode"
|
||||
)
|
||||
|
|
@ -669,13 +647,13 @@ async def repaint_background_async(body: RepaintBackgroundRequest) -> dict[str,
|
|||
"height": gen_h,
|
||||
"output_base64": output_b64,
|
||||
"seed": seed,
|
||||
"stages_completed": 3,
|
||||
"stages_completed": 2,
|
||||
"total_duration_ms": 0,
|
||||
"error": None,
|
||||
}
|
||||
await job_storage.set_result(job.id, result)
|
||||
await job_storage.update_status(
|
||||
job.id, StorageJobStatus.COMPLETED, stages_completed=3
|
||||
job.id, StorageJobStatus.COMPLETED, stages_completed=2
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
"""SDXL inpainting-based background replacement.
|
||||
"""SDXL inpainting-based background replacement with model-boss GPU coordination.
|
||||
|
||||
Uses AutoPipelineForInpainting with `diffusers/stable-diffusion-xl-1.0-inpainting-0.1`
|
||||
to replace image backgrounds while preserving the subject seamlessly. The model
|
||||
processes the full image context so lighting, shadows, and edge blending are
|
||||
handled naturally — no hard compositing artefacts.
|
||||
Full GPU pipeline:
|
||||
1. BiRefNet segmentation (CUDA ONNX) — SOTA subject extraction + mask feathering
|
||||
2. SDXL inpainting (AutoPipelineForInpainting) — regenerates background in full
|
||||
image context so lighting, shadows, and edge blending are handled by the model
|
||||
|
||||
GPU access is serialised via a shared asyncio.Lock that is also held by the
|
||||
GenerationQueue worker, preventing VRAM conflicts between concurrent requests.
|
||||
GPU access is coordinated via model-boss: `acquire_lease()` puts the request in
|
||||
the priority queue and returns only when a GPU slot is available. This integrates
|
||||
with all other GPU consumers on the system (inference, training, identity) rather
|
||||
than grabbing VRAM unilaterally.
|
||||
|
||||
Both BiRefNet and SDXL inpainting run under the same model-boss lease, serialized
|
||||
via the shared asyncio.Lock so they never overlap with GenerationQueue GPU work.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -19,53 +24,71 @@ from PIL import Image
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_ID = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
|
||||
INPAINTING_MODEL_ID = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
|
||||
INPAINTING_VRAM_MB = 6000
|
||||
|
||||
|
||||
class BackgroundInpainter:
|
||||
"""Lazy-loaded SDXL inpainting pipeline for background replacement.
|
||||
"""BiRefNet segmentation + SDXL inpainting, fully coordinated via model-boss.
|
||||
|
||||
Thread-safe: all GPU work runs in a thread executor while the asyncio
|
||||
event loop holds the shared gpu_lock, preventing overlap with the
|
||||
GenerationQueue worker.
|
||||
On the first repaint call:
|
||||
- Acquires a GPU lease from model-boss (blocks in the priority queue until
|
||||
a GPU slot with sufficient VRAM is available)
|
||||
- Loads BiRefNet segmentation model (CUDA ONNX) on the leased GPU
|
||||
- Loads SDXL inpainting pipeline (fp16) on the leased GPU
|
||||
- Starts a heartbeat task to keep the lease alive
|
||||
|
||||
Subsequent calls reuse the loaded models. If the lease is evicted by
|
||||
model-boss (higher-priority work or idle timeout), models reload on the
|
||||
next request.
|
||||
|
||||
All GPU work runs in a thread executor under the shared gpu_lock, which
|
||||
also serializes GenerationQueue jobs — no VRAM contention possible.
|
||||
"""
|
||||
|
||||
def __init__(self, gpu_lock: asyncio.Lock) -> None:
|
||||
self._gpu_lock = gpu_lock
|
||||
self._pipeline: Optional[object] = None
|
||||
self._device = "cuda:0"
|
||||
self._seg_session: Optional[object] = None
|
||||
self._device: str = "cuda:0" # Updated after lease assignment
|
||||
self._mb_client: Optional[object] = None
|
||||
self._mb_lease_id: Optional[str] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def repaint(
|
||||
self,
|
||||
source_image: Image.Image,
|
||||
background_mask: Image.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
) -> Image.Image:
|
||||
"""Replace the background in `source_image` using SDXL inpainting.
|
||||
"""Segment subject then replace background using SDXL inpainting.
|
||||
|
||||
Acquires the shared gpu_lock so this never overlaps with GenerationQueue.
|
||||
Acquires a model-boss lease on first call to register VRAM usage with
|
||||
the system coordinator.
|
||||
|
||||
Args:
|
||||
source_image: RGB source photo (subject to preserve).
|
||||
background_mask: L-mode mask — white (255) = replace, black (0) = keep.
|
||||
source_image: RGB source photo. Subject is preserved; background replaced.
|
||||
prompt: Background scene description.
|
||||
negative_prompt: Negative prompt for the inpainting pass.
|
||||
steps: Number of diffusion steps (20-40 recommended).
|
||||
guidance_scale: CFG scale (7-8 recommended for photorealism).
|
||||
steps: Number of diffusion steps.
|
||||
guidance_scale: CFG scale.
|
||||
seed: Deterministic seed.
|
||||
|
||||
Returns:
|
||||
RGB PIL image with the background replaced.
|
||||
RGB PIL image with background seamlessly replaced.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
async with self._gpu_lock:
|
||||
if self._pipeline is None:
|
||||
await self._acquire_and_load(loop)
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
self._run_sync,
|
||||
source_image,
|
||||
background_mask,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
|
|
@ -73,21 +96,140 @@ class BackgroundInpainter:
|
|||
seed,
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Cancel heartbeat and release model-boss lease."""
|
||||
if self._heartbeat_task is not None:
|
||||
self._heartbeat_task.cancel()
|
||||
try:
|
||||
await self._heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._heartbeat_task = None
|
||||
if self._mb_client is not None and self._mb_lease_id is not None:
|
||||
try:
|
||||
await self._mb_client.release_lease(self._mb_lease_id)
|
||||
logger.info("Inpainting lease %s released", self._mb_lease_id)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to release inpainting lease: %r", exc)
|
||||
self._mb_lease_id = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Private — lease + load
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _acquire_and_load(self, loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Acquire model-boss GPU lease, then load models in thread executor."""
|
||||
from model_boss.client import InferenceClient
|
||||
|
||||
logger.info(
|
||||
"Acquiring model-boss GPU lease for inpainting (vram=%dMB)", INPAINTING_VRAM_MB
|
||||
)
|
||||
self._mb_client = InferenceClient(
|
||||
client_id="imajin-inpainting",
|
||||
auto_start_services=False,
|
||||
)
|
||||
lease = await self._mb_client.acquire_lease(
|
||||
model_id="imajin-pipeline:inpainting",
|
||||
vram_mb=INPAINTING_VRAM_MB,
|
||||
priority="high",
|
||||
endpoint="inpainting",
|
||||
)
|
||||
self._mb_lease_id = lease["lease_id"]
|
||||
gpu_index = lease["gpu_index"]
|
||||
self._device = f"cuda:{gpu_index}"
|
||||
logger.info(
|
||||
"Inpainting lease acquired: GPU %d, lease=%s",
|
||||
gpu_index, self._mb_lease_id,
|
||||
)
|
||||
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
await loop.run_in_executor(None, self._load_sync)
|
||||
|
||||
async def _heartbeat_loop(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(10.0)
|
||||
alive = await self._mb_client.heartbeat(self._mb_lease_id) # type: ignore[union-attr]
|
||||
if not alive:
|
||||
logger.warning(
|
||||
"Inpainting lease %s evicted — models will reload on next request",
|
||||
self._mb_lease_id,
|
||||
)
|
||||
self._pipeline = None
|
||||
self._seg_session = None
|
||||
self._mb_lease_id = None
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.warning("Inpainting lease heartbeat failed: %r", exc)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Private — sync (runs in thread executor, under gpu_lock + lease)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _load_sync(self) -> None:
|
||||
"""Load BiRefNet (CUDA ONNX) and SDXL inpainting pipeline."""
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from rembg import new_session
|
||||
|
||||
gpu_id = int(self._device.split(":")[-1])
|
||||
logger.info("Loading BiRefNet segmentation model on %s", self._device)
|
||||
self._seg_session = new_session(
|
||||
"birefnet-general",
|
||||
providers=[
|
||||
("CUDAExecutionProvider", {"device_id": gpu_id}),
|
||||
"CPUExecutionProvider",
|
||||
],
|
||||
)
|
||||
logger.info("BiRefNet loaded on %s", self._device)
|
||||
|
||||
logger.info(
|
||||
"Loading SDXL inpainting model: %s on %s (first use — may download ~6 GB)",
|
||||
INPAINTING_MODEL_ID, self._device,
|
||||
)
|
||||
try:
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
INPAINTING_MODEL_ID,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("fp16 variant unavailable, loading full-precision weights")
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
INPAINTING_MODEL_ID,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipeline = pipeline.to(self._device)
|
||||
pipeline.enable_attention_slicing()
|
||||
pipeline.enable_vae_slicing()
|
||||
self._pipeline = pipeline
|
||||
logger.info("SDXL inpainting pipeline ready on %s", self._device)
|
||||
|
||||
def _run_sync(
|
||||
self,
|
||||
source_image: Image.Image,
|
||||
background_mask: Image.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
) -> Image.Image:
|
||||
"""BiRefNet segmentation + SDXL inpainting (blocking, runs in thread executor)."""
|
||||
import torch
|
||||
from PIL import ImageFilter, ImageOps
|
||||
from rembg import remove as rembg_remove
|
||||
|
||||
if self._pipeline is None:
|
||||
self._load()
|
||||
# BiRefNet → feathered background mask
|
||||
rgba = rembg_remove(source_image, session=self._seg_session)
|
||||
alpha = rgba.split()[3] # white = subject
|
||||
bg_mask = ImageOps.invert(alpha) # white = background
|
||||
bg_mask = bg_mask.filter(ImageFilter.MaxFilter(11)) # dilate ~5px inward
|
||||
bg_mask = bg_mask.filter(ImageFilter.GaussianBlur(radius=20)) # feather transition
|
||||
background_mask = bg_mask.convert("L")
|
||||
|
||||
# SDXL inpainting — regenerates background in full image context
|
||||
generator = torch.Generator(device=self._device).manual_seed(seed)
|
||||
result = self._pipeline( # type: ignore[misc]
|
||||
prompt=prompt,
|
||||
|
|
@ -100,28 +242,3 @@ class BackgroundInpainter:
|
|||
generator=generator,
|
||||
)
|
||||
return result.images[0]
|
||||
|
||||
def _load(self) -> None:
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
|
||||
logger.info("Loading SDXL inpainting model: %s (first use — may download ~6 GB)", MODEL_ID)
|
||||
try:
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
except Exception:
|
||||
# Some HF mirrors don't have fp16 variant files — fall back to full precision
|
||||
logger.warning("fp16 variant unavailable, loading full-precision weights")
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = pipeline.to(self._device)
|
||||
pipeline.enable_attention_slicing()
|
||||
pipeline.enable_vae_slicing()
|
||||
self._pipeline = pipeline
|
||||
logger.info("SDXL inpainting model ready on %s", self._device)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue