feat(api-routes): ✨ Add support for new generation parameters like model_id and steps in the /generate endpoint
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
7970811a61
commit
b62ccae718
1 changed file with 200 additions and 16 deletions
|
|
@ -222,37 +222,76 @@ class RecolorRequest(BaseModel):
|
|||
output_format: Literal["png", "webp"] = Field("png")
|
||||
|
||||
|
||||
import re as _re
|
||||
|
||||
_SFW_NEGATIVE_TERMS = (
|
||||
"nudity, nude, naked, topless, bare breasts, nipples, genitals, "
|
||||
"explicit sexual content, nsfw"
|
||||
"nudity, nude, naked, topless, bare breasts, nipples, exposed breasts, "
|
||||
"bare chest, genitals, explicit sexual content, nsfw, exposed cleavage, bare cleavage, "
|
||||
"revealing outfit, unbuttoned shirt, open top, see-through"
|
||||
)
|
||||
|
||||
_NSFW_NEGATIVE_TERMS = (
|
||||
"child, minor, underage, loli, shota"
|
||||
)
|
||||
|
||||
# Anatomy phrases that bias photorealistic models toward nudity — strip from SFW prompts
|
||||
_SFW_ANATOMY_STRIP_RE = _re.compile(
|
||||
r"\b(?:natural\s+)?[A-K]\s*cup\s*(?:breasts?)?\b"
|
||||
r"|\bbreasts?\b"
|
||||
r"|\bnipples?\b"
|
||||
r"|\bbust:\s*\d+\s*(?:inch|in|\")?"
|
||||
r"|\bunderboob:\s*\d+\s*(?:inch|in|\")?"
|
||||
r"|\bbare\s+chest\b"
|
||||
r"|\bexposed\s+skin\b",
|
||||
_re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_prompt_for_sfw(prompt: str) -> str:
|
||||
"""Strip anatomy terms that bias photorealistic models toward nudity."""
|
||||
sanitized = _SFW_ANATOMY_STRIP_RE.sub("", prompt)
|
||||
# Collapse duplicate commas/spaces left by removal
|
||||
sanitized = _re.sub(r",\s*,", ",", sanitized)
|
||||
sanitized = _re.sub(r"\s{2,}", " ", sanitized).strip().strip(",").strip()
|
||||
return sanitized
|
||||
|
||||
|
||||
# Whole-word clothing cues.  Word boundaries matter: a plain substring test
# would treat "undressed" or "unclothed" as clothing cues and skip the anchor.
_CLOTHING_CUE_RE = _re.compile(r"\b(?:clothed|dressed|wearing)\b", _re.IGNORECASE)

# Phrase placed at the front of SFW prompts that carry no clothing cue.
_CLOTHING_ANCHOR = "fully clothed, dressed"


def _enforce_rating(
    maturity_rating: str, prompt: str, negative_prompt: Optional[str]
) -> tuple[str, str]:
    """Enforce content rating on both positive and negative prompts.

    For ``"sfw"``: strips anatomy terms from the positive prompt, injects a
    clothing anchor when the prompt has no clothing cue, and appends the
    nudity-blocking negative terms.
    For ``"nsfw"`` / ``"explicit"``: appends the minor-protection terms to the
    negative prompt only; the positive prompt is returned unchanged.
    Any other rating leaves both prompts as given.

    Args:
        maturity_rating: Requested rating (``"sfw"``, ``"nsfw"`` or ``"explicit"``).
        prompt: Positive prompt supplied by the caller.
        negative_prompt: Caller's negative prompt, or ``None``.

    Returns:
        ``(sanitized_prompt, negative_prompt)``; the negative prompt is always
        a string (``""`` when nothing was supplied and nothing was added).
    """
    base_neg = negative_prompt or ""

    if maturity_rating == "sfw":
        prompt = _sanitize_prompt_for_sfw(prompt)
        # Inject the clothing anchor at the start of the positive prompt so it
        # carries high weight against any residual nudity priors.
        if not _CLOTHING_CUE_RE.search(prompt):
            # Sanitising can leave the prompt empty; avoid a dangling ", ".
            prompt = f"{_CLOTHING_ANCHOR}, {prompt}" if prompt else _CLOTHING_ANCHOR
        # Substring check keeps the call idempotent when the terms are
        # already present in the caller's negative prompt.
        if _SFW_NEGATIVE_TERMS not in base_neg:
            base_neg = f"{base_neg}, {_SFW_NEGATIVE_TERMS}".strip(", ")
    elif maturity_rating in ("nsfw", "explicit"):
        # Block illegal content regardless of rating.
        if _NSFW_NEGATIVE_TERMS not in base_neg:
            base_neg = f"{base_neg}, {_NSFW_NEGATIVE_TERMS}".strip(", ")

    return prompt, base_neg
|
||||
|
||||
|
||||
def _build_pipeline_request(body: GenerateRequest) -> ImagePipelineRequest:
|
||||
"""Convert GenerateRequest body to ImagePipelineRequest."""
|
||||
prompt, negative_prompt = _enforce_rating(
|
||||
body.maturity_rating, body.prompt, body.negative_prompt
|
||||
)
|
||||
return ImagePipelineRequest(
|
||||
prompt=body.prompt,
|
||||
negative_prompt=_enforce_rating(body.maturity_rating, body.negative_prompt),
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=body.model,
|
||||
layout=body.layout,
|
||||
width=body.width,
|
||||
|
|
@ -507,6 +546,151 @@ async def generate_recolor(body: RecolorRequest) -> GenerateResponse:
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def _repaint_field_alias(field_name: str) -> str:
    """Return the camelCase alias for a snake_case field name."""
    first, *rest = field_name.split('_')
    return first + ''.join(part.capitalize() for part in rest)


class RepaintBackgroundRequest(BaseModel):
    """Request body for repainting a background while preserving the subject.

    Pipeline stages:
    1. rembg segments the subject from source_image → RGBA
    2. SDXL generates a new background from background_prompt
    3. PIL composites subject over generated background

    Fields may be supplied under their snake_case names or their camelCase
    aliases (``populate_by_name`` is enabled).
    """

    model_config = ConfigDict(
        populate_by_name=True,
        alias_generator=_repaint_field_alias,
    )

    # Input image and prompts.
    source_image: str = Field(..., description="Base64-encoded source photo (person to preserve)")
    background_prompt: str = Field(..., description="Description of desired background scene")
    negative_prompt: Optional[str] = Field(None, description="Negative prompt for background generation")

    # Generation settings for the background.
    model: str = Field("juggernaut-xi-v11", description="SDXL model for background generation")
    steps: int = Field(30, ge=1, le=50, description="Inference steps")
    guidance_scale: float = Field(7.5, ge=1.0, le=20.0, description="CFG scale")
    seed: Optional[int] = Field(None, description="Random seed")

    # Content rating and output encoding.
    maturity_rating: Literal["sfw", "nsfw", "explicit"] = Field("sfw")
    output_format: Literal["png", "webp"] = Field("png")
|
||||
|
||||
|
||||
@router.post("/repaint-background", response_model=GenerateResponse)
async def repaint_background(body: RepaintBackgroundRequest) -> GenerateResponse:
    """Replace image background while preserving the subject.

    Generates a new background scene from the prompt via the generation
    queue, segments the subject from the source photo using rembg, then
    composites the subject over the generated background.

    Args:
        body: Source image (base64, optionally a ``data:`` URI), background
            prompt and generation settings.

    Returns:
        A ``GenerateResponse`` whose ``result`` carries the composited image
        as base64 along with its dimensions and seed.

    Raises:
        HTTPException: 400 when ``source_image`` cannot be decoded as an
            image; 500 when background generation or compositing fails.
    """
    import base64
    import io
    import random

    from PIL import Image

    from ..main import lifespan
    from ...generation_queue import GenerationQueue

    queue: GenerationQueue = lifespan.get_state("generation_queue")

    try:
        # 1. Decode source image (no GPU work yet).  Anything wrong here is
        #    the caller's input, so report it as a 400 rather than a 500.
        try:
            source_data = body.source_image
            if source_data.startswith("data:"):
                # IndexError if the data URI has no "," separator.
                source_data = source_data.split(",", 1)[1]
            # binascii.Error (a ValueError) on malformed base64.
            source_bytes = base64.b64decode(source_data)
            # OSError / UnidentifiedImageError on bytes that are not an image.
            source_image = Image.open(io.BytesIO(source_bytes)).convert("RGB")
        except (IndexError, ValueError, OSError) as exc:
            raise HTTPException(
                status_code=400, detail=f"Invalid source_image: {exc}"
            ) from exc
        source_w, source_h = source_image.size

        # 2. Round each dimension down to a multiple of 64, clamped to the
        #    512-1536 range.  Width and height are clamped independently, so
        #    the aspect ratio can change for very small or very large sources;
        #    the subject is resized to the same target size in step 5.
        gen_w = max(512, min(1536, (source_w // 64) * 64))
        gen_h = max(512, min(1536, (source_h // 64) * 64))

        # 3. Generate background via model-boss queue FIRST — this acquires the
        #    GPU lease before rembg/onnxruntime has a chance to initialise a CUDA
        #    context and fragment VRAM.
        background_prompt, bg_negative = _enforce_rating(
            body.maturity_rating, body.background_prompt, body.negative_prompt
        )
        pipeline_request = ImagePipelineRequest(
            prompt=background_prompt,
            negative_prompt=bg_negative,
            model=body.model,
            layout="custom",
            width=gen_w,
            height=gen_h,
            steps=body.steps,
            guidance_scale=body.guidance_scale,
            seed=body.seed,
            maturity_rating=body.maturity_rating,
            enable_moderation=False,
        )

        events = await queue.submit(pipeline_request)
        bg_result = None
        async for event in events:
            if event["type"] == "complete":
                bg_result = event["result"]
                break
            if event["type"] == "error":
                raise HTTPException(status_code=500, detail=event["message"])

        if not bg_result or not bg_result.get("output_base64"):
            raise HTTPException(status_code=500, detail="Background generation produced no output")

        # 4. Segment subject via rembg AFTER GPU work completes.
        #    CPUExecutionProvider is the only provider requested.  Both session
        #    creation and inference run in the thread executor so segmentation
        #    does not block the event loop.  A new session is created on
        #    every request.
        from rembg import new_session, remove as rembg_remove

        def _segment(img: Image.Image) -> Image.Image:
            seg_session = new_session("u2net", providers=["CPUExecutionProvider"])
            return rembg_remove(img, session=seg_session)  # type: ignore[return-value]

        subject_rgba: Image.Image = await asyncio.get_running_loop().run_in_executor(
            None, _segment, source_image
        )

        # 5. Composite: paste RGBA subject over generated background, using
        #    the subject's own alpha channel as the paste mask.
        bg_bytes = base64.b64decode(bg_result["output_base64"])
        background = Image.open(io.BytesIO(bg_bytes)).convert("RGBA")
        background = background.resize((gen_w, gen_h), Image.LANCZOS)

        subject_resized = subject_rgba.resize((gen_w, gen_h), Image.LANCZOS)
        background.paste(subject_resized, mask=subject_resized.split()[3])

        # 6. Encode result
        output_buffer = io.BytesIO()
        if body.output_format == "webp":
            background.convert("RGB").save(output_buffer, format="WEBP", quality=95)
        else:
            background.save(output_buffer, format="PNG")
        output_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8")

        # Report the seed that identifies this generation.  An explicit
        # request seed wins; otherwise prefer a seed reported by the pipeline
        # result (assumes the result dict may carry a "seed" key — TODO
        # confirm), and only then fall back to a random identifier.
        if body.seed is not None:
            result_seed = body.seed
        else:
            result_seed = bg_result.get("seed")
            if result_seed is None:
                result_seed = random.randint(0, 2**31 - 1)

        return GenerateResponse(
            success=True,
            result={
                "job_id": f"repaint-{result_seed}",
                "status": "completed",
                "width": gen_w,
                "height": gen_h,
                "output_base64": output_base64,
                "seed": result_seed,
                "stages_completed": 2,
                "total_duration_ms": bg_result.get("total_duration_ms", 0),
                "error": None,
            },
        )

    except HTTPException:
        # Already carries the intended status code; pass through untouched.
        raise
    except Exception as e:
        logger.error("Background repaint failed: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/async")
|
||||
async def generate_async(request: GenerateRequest) -> dict[str, Any]:
|
||||
"""Create an async generation job.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue