feat(api-routes): Add support for new generation parameters like model_id and steps in the /generate endpoint

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Claude Code 2026-03-30 21:36:59 -07:00
parent 7970811a61
commit b62ccae718

View file

@ -222,37 +222,76 @@ class RecolorRequest(BaseModel):
output_format: Literal["png", "webp"] = Field("png")
import re as _re
_SFW_NEGATIVE_TERMS = (
"nudity, nude, naked, topless, bare breasts, nipples, genitals, "
"explicit sexual content, nsfw"
"nudity, nude, naked, topless, bare breasts, nipples, exposed breasts, "
"bare chest, genitals, explicit sexual content, nsfw, exposed cleavage, bare cleavage, "
"revealing outfit, unbuttoned shirt, open top, see-through"
)
_NSFW_NEGATIVE_TERMS = (
"child, minor, underage, loli, shota"
)
# Anatomy phrases that bias photorealistic models toward nudity — strip from SFW prompts
_SFW_ANATOMY_STRIP_RE = _re.compile(
r"\b(?:natural\s+)?[A-K]\s*cup\s*(?:breasts?)?\b"
r"|\bbreasts?\b"
r"|\bnipples?\b"
r"|\bbust:\s*\d+\s*(?:inch|in|\")?"
r"|\bunderboob:\s*\d+\s*(?:inch|in|\")?"
r"|\bbare\s+chest\b"
r"|\bexposed\s+skin\b",
_re.IGNORECASE,
)
def _sanitize_prompt_for_sfw(prompt: str) -> str:
"""Strip anatomy terms that bias photorealistic models toward nudity."""
sanitized = _SFW_ANATOMY_STRIP_RE.sub("", prompt)
# Collapse duplicate commas/spaces left by removal
sanitized = _re.sub(r",\s*,", ",", sanitized)
sanitized = _re.sub(r"\s{2,}", " ", sanitized).strip().strip(",").strip()
return sanitized
def _enforce_rating(
maturity_rating: str, prompt: str, negative_prompt: Optional[str]
) -> tuple[str, str]:
"""Enforce content rating on both positive and negative prompts.
For SFW: strips anatomy terms from positive prompt, injects clothing
anchors, and adds nudity-blocking negative terms.
For NSFW/explicit: adds minor-protection terms to negative only.
Returns (sanitized_prompt, negative_prompt).
"""
base_neg = negative_prompt or ""
def _enforce_rating(maturity_rating: str, negative_prompt: Optional[str]) -> str:
"""Inject content-rating enforcement terms into the negative prompt."""
base = negative_prompt or ""
if maturity_rating == "sfw":
# Block all nudity/explicit content
terms = _SFW_NEGATIVE_TERMS
prompt = _sanitize_prompt_for_sfw(prompt)
# Inject clothing anchor at the start of the positive prompt so it
# carries high weight against any residual nudity priors.
if not any(kw in prompt.lower() for kw in ("fully clothed", "clothed", "dressed", "wearing")):
prompt = f"fully clothed, dressed, {prompt}"
if _SFW_NEGATIVE_TERMS not in base_neg:
base_neg = f"{base_neg}, {_SFW_NEGATIVE_TERMS}".strip(", ")
elif maturity_rating in ("nsfw", "explicit"):
# Block illegal content regardless of rating
terms = _NSFW_NEGATIVE_TERMS
else:
return base
if terms in base:
return base
return f"{base}, {terms}".strip(", ")
if _NSFW_NEGATIVE_TERMS not in base_neg:
base_neg = f"{base_neg}, {_NSFW_NEGATIVE_TERMS}".strip(", ")
return prompt, base_neg
def _build_pipeline_request(body: GenerateRequest) -> ImagePipelineRequest:
"""Convert GenerateRequest body to ImagePipelineRequest."""
prompt, negative_prompt = _enforce_rating(
body.maturity_rating, body.prompt, body.negative_prompt
)
return ImagePipelineRequest(
prompt=body.prompt,
negative_prompt=_enforce_rating(body.maturity_rating, body.negative_prompt),
prompt=prompt,
negative_prompt=negative_prompt,
model=body.model,
layout=body.layout,
width=body.width,
@ -507,6 +546,151 @@ async def generate_recolor(body: RecolorRequest) -> GenerateResponse:
raise HTTPException(status_code=500, detail=str(e))
class RepaintBackgroundRequest(BaseModel):
"""Repaint background while preserving the subject.
Two-step pipeline:
1. rembg segments the subject from source_image RGBA
2. SDXL generates a new background from background_prompt
3. PIL composites subject over generated background
"""
model_config = ConfigDict(
populate_by_name=True,
alias_generator=lambda field_name: ''.join(
word.capitalize() if i > 0 else word
for i, word in enumerate(field_name.split('_'))
),
)
source_image: str = Field(..., description="Base64-encoded source photo (person to preserve)")
background_prompt: str = Field(..., description="Description of desired background scene")
negative_prompt: Optional[str] = Field(None, description="Negative prompt for background generation")
model: str = Field("juggernaut-xi-v11", description="SDXL model for background generation")
steps: int = Field(30, ge=1, le=50, description="Inference steps")
guidance_scale: float = Field(7.5, ge=1.0, le=20.0, description="CFG scale")
seed: Optional[int] = Field(None, description="Random seed")
maturity_rating: Literal["sfw", "nsfw", "explicit"] = Field("sfw")
output_format: Literal["png", "webp"] = Field("png")
@router.post("/repaint-background", response_model=GenerateResponse)
async def repaint_background(body: RepaintBackgroundRequest) -> GenerateResponse:
"""Replace image background while preserving the subject.
Segments the person from the source photo using rembg, generates a new
background scene from the prompt via SDXL, then composites subject over it.
"""
import base64
import io
from PIL import Image
from ..main import lifespan
from ...generation_queue import GenerationQueue
queue: GenerationQueue = lifespan.get_state("generation_queue")
try:
# 1. Decode source image and compute SDXL-safe dimensions (no GPU work yet)
source_data = body.source_image
if source_data.startswith("data:"):
source_data = source_data.split(",", 1)[1]
source_bytes = base64.b64decode(source_data)
source_image = Image.open(io.BytesIO(source_bytes)).convert("RGB")
source_w, source_h = source_image.size
# 2. Snap dimensions to nearest 64, clamped to SDXL-safe range
gen_w = max(512, min(1536, (source_w // 64) * 64))
gen_h = max(512, min(1536, (source_h // 64) * 64))
# 3. Generate background via model-boss queue FIRST — this acquires the
# GPU lease before rembg/onnxruntime has a chance to initialise a CUDA
# context and fragment VRAM.
background_prompt, bg_negative = _enforce_rating(
body.maturity_rating, body.background_prompt, body.negative_prompt
)
pipeline_request = ImagePipelineRequest(
prompt=background_prompt,
negative_prompt=bg_negative,
model=body.model,
layout="custom",
width=gen_w,
height=gen_h,
steps=body.steps,
guidance_scale=body.guidance_scale,
seed=body.seed,
maturity_rating=body.maturity_rating,
enable_moderation=False,
)
events = await queue.submit(pipeline_request)
bg_result = None
async for event in events:
if event["type"] == "complete":
bg_result = event["result"]
break
if event["type"] == "error":
raise HTTPException(status_code=500, detail=event["message"])
if not bg_result or not bg_result.get("output_base64"):
raise HTTPException(status_code=500, detail="Background generation produced no output")
# 4. Segment subject via rembg AFTER GPU work completes.
# CPUExecutionProvider is the only provider — onnxruntime never touches
# the GPU. Both session creation (disk I/O) and inference run in the
# thread executor so the event loop is never blocked.
from rembg import new_session, remove as rembg_remove
def _segment(img: Image.Image) -> Image.Image:
seg_session = new_session("u2net", providers=["CPUExecutionProvider"])
return rembg_remove(img, session=seg_session) # type: ignore[return-value]
subject_rgba: Image.Image = await asyncio.get_running_loop().run_in_executor(
None, _segment, source_image
)
# 5. Composite: paste RGBA subject over generated background
bg_bytes = base64.b64decode(bg_result["output_base64"])
background = Image.open(io.BytesIO(bg_bytes)).convert("RGBA")
background = background.resize((gen_w, gen_h), Image.LANCZOS)
subject_resized = subject_rgba.resize((gen_w, gen_h), Image.LANCZOS)
background.paste(subject_resized, mask=subject_resized.split()[3])
# 6. Encode result
output_buffer = io.BytesIO()
if body.output_format == "webp":
background.convert("RGB").save(output_buffer, format="WEBP", quality=95)
else:
background.save(output_buffer, format="PNG")
output_base64 = base64.b64encode(output_buffer.getvalue()).decode("utf-8")
import random
result_seed = body.seed if body.seed is not None else random.randint(0, 2**31 - 1)
return GenerateResponse(
success=True,
result={
"job_id": f"repaint-{result_seed}",
"status": "completed",
"width": gen_w,
"height": gen_h,
"output_base64": output_base64,
"seed": result_seed,
"stages_completed": 2,
"total_duration_ms": bg_result.get("total_duration_ms", 0),
"error": None,
},
)
except HTTPException:
raise
except Exception as e:
logger.error("Background repaint failed: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/async")
async def generate_async(request: GenerateRequest) -> dict[str, Any]:
"""Create an async generation job.