198 lines
7.3 KiB
Python
198 lines
7.3 KiB
Python
"""Repaint command — background replacement via SDXL inpainting.
|
||
|
||
Usage:
|
||
./run repaint --source photo.jpg --prompt "luxury hotel suite, city view"
|
||
./run repaint --source photo.jpg --prompt "hotel suite" --count 4 --out ./results/
|
||
./run repaint --source photo.jpg --prompt "..." --seed 42 --count 3
|
||
"""
|
||
|
||
import argparse
|
||
import base64
|
||
import io
|
||
import json
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
import requests
|
||
|
||
|
||
def _encode_image(path: Path) -> str:
|
||
"""Read image file and return raw base64 string."""
|
||
return base64.b64encode(path.read_bytes()).decode()
|
||
|
||
|
||
def _submit_job(
    url: str,
    source_b64: str,
    background_prompt: str,
    negative_prompt: Optional[str],
    steps: int,
    guidance_scale: float,
    seed: int,
    rating: str,
) -> str:
    """Queue one asynchronous background-repaint job and return its job id.

    Args:
        url: Base URL of the diffusion service.
        source_b64: Source photo as a base64 string (sent as ``sourceImage``).
        background_prompt: Scene description (sent as ``backgroundPrompt``).
        negative_prompt: Optional negative prompt; ``None`` is sent as JSON null.
        steps: Inference step count.
        guidance_scale: CFG guidance scale.
        seed: Seed for this variant.
        rating: Content rating (sent as ``maturityRating``).

    Returns:
        The ``jobId`` reported by the service.

    Raises:
        requests.HTTPError: The service answered with an error status.
        RuntimeError: The reply lacked a truthy ``success`` flag or a ``jobId``.
    """
    endpoint = f"{url}/generate/repaint-background/async"
    body = dict(
        sourceImage=source_b64,
        backgroundPrompt=background_prompt,
        negativePrompt=negative_prompt,
        steps=steps,
        guidanceScale=guidance_scale,
        seed=seed,
        maturityRating=rating,
    )

    response = requests.post(endpoint, json=body, timeout=30)
    response.raise_for_status()
    reply = response.json()

    job_id = reply.get("jobId")
    if reply.get("success") and job_id:
        return job_id
    raise RuntimeError(f"Submit failed: {reply}")
|
||
|
||
|
||
def _poll_jobs(url: str, job_ids: list[str], interval: float = 3.0) -> dict[str, dict]:
|
||
"""Poll until all jobs are terminal. Returns {job_id: result_data}."""
|
||
pending = set(job_ids)
|
||
results: dict[str, dict] = {}
|
||
|
||
while pending:
|
||
time.sleep(interval)
|
||
for job_id in list(pending):
|
||
resp = requests.get(f"{url}/jobs/{job_id}", timeout=10)
|
||
resp.raise_for_status()
|
||
data = resp.json()
|
||
status = data.get("status")
|
||
if status == "completed":
|
||
result_resp = requests.get(f"{url}/jobs/{job_id}/result", timeout=30)
|
||
result_resp.raise_for_status()
|
||
result_data = result_resp.json()
|
||
results[job_id] = result_data
|
||
pending.discard(job_id)
|
||
print(f" ✓ {job_id[:8]} done")
|
||
elif status == "failed":
|
||
results[job_id] = {"error": data.get("error", "failed")}
|
||
pending.discard(job_id)
|
||
print(f" ✗ {job_id[:8]} failed: {data.get('error', '?')}", file=sys.stderr)
|
||
|
||
return results
|
||
|
||
|
||
def _build_repaint_parser() -> argparse.ArgumentParser:
    """Return the argument parser for ``./run repaint``."""
    parser = argparse.ArgumentParser(
        prog="./run repaint",
        description="Replace image background using SDXL inpainting (BiRefNet segmentation)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Single repaint
./run repaint --source photo.jpg --prompt "luxury hotel suite, city skyline"

# Batch of 4 variants from the same source
./run repaint --source photo.jpg --prompt "hotel suite" --count 4 --out ./results/

# Deterministic batch (seed+0, seed+1, ...)
./run repaint --source photo.jpg --prompt "hotel suite" --seed 42 --count 3

# NSFW rating (removes SFW clothing-lock from prompt)
./run repaint --source photo.jpg --prompt "hotel suite" --rating nsfw
""",
    )
    parser.add_argument("--source", "-s", required=True, type=Path, help="Source photo path")
    parser.add_argument("--prompt", "-p", required=True, help="Background scene description")
    parser.add_argument("--negative", "-n", default=None, help="Negative prompt")
    parser.add_argument("--count", "-c", type=int, default=1, help="Number of variants (default: 1)")
    parser.add_argument("--seed", type=int, default=None, help="Starting seed (increments per variant)")
    parser.add_argument("--steps", type=int, default=35, help="Inference steps (default: 35)")
    parser.add_argument("--guidance", type=float, default=7.5, help="CFG guidance scale (default: 7.5)")
    parser.add_argument(
        "--rating",
        choices=["sfw", "nsfw", "explicit"],
        default="nsfw",
        help="Content rating (default: nsfw)",
    )
    parser.add_argument("--out", "-o", type=Path, default=None, help="Output directory (or file if count=1)")
    parser.add_argument("--url", default="http://localhost:8002", help="Diffusion service URL")
    return parser


def _resolve_output(out: Optional[Path], count: int) -> tuple[Path, Optional[Path]]:
    """Map the ``--out`` argument to ``(out_dir, single_out)``.

    ``single_out`` is set only when exactly one image is requested and ``out``
    names an image file (suffix compared case-insensitively). Otherwise
    ``single_out`` is None and images go into ``out_dir`` under generated names.
    """
    out_path = (out or Path(".")).expanduser().resolve()
    if count == 1 and out_path.suffix.lower() in (".png", ".jpg", ".jpeg", ".webp"):
        return out_path.parent, out_path
    # A suffixed path is treated as a file name and its parent is used. The
    # exception is an existing directory with a dot in its name (e.g. "results.v2").
    if out_path.is_dir() or not out_path.suffix:
        return out_path, None
    return out_path.parent, None


def _extract_image(result: object) -> Optional[dict]:
    """Return the mapping carrying a non-empty ``output_base64``, or None.

    Accepts both shapes: the payload nested under a ``"result"`` key, or
    carried at the top level. Error entries such as ``{"error": ...}`` yield None.
    """
    if not isinstance(result, dict):
        return None
    payload = result.get("result", result)
    if isinstance(payload, dict) and payload.get("output_base64"):
        return payload
    return None


def repaint_command(args: list[str], workspace_root: Path) -> int:
    """Run ``./run repaint``: submit background-repaint jobs and save the images.

    Steps:
    1. Validate the arguments and check the diffusion service's ``/health``.
    2. Submit one async job per variant, using seeds ``base, base+1, ...``.
    3. Poll the jobs until they finish.
    4. Write each returned image.

    With ``--count 1`` and an image-file ``--out``, the image goes to that file.
    Otherwise variant *n* (1-based, matching seed ``base + n - 1``) is written to
    ``<out_dir>/<source stem>_repaint_<n>.png``.

    Args:
        args: Command-line arguments (without the command name).
        workspace_root: Not used by this command.

    Returns:
        0 if at least one image was saved, otherwise 1.
    """
    import random

    parser = _build_repaint_parser()
    parsed = parser.parse_args(args)

    if parsed.count < 1:
        print(f"--count must be at least 1 (got {parsed.count})", file=sys.stderr)
        return 1

    source_path = parsed.source.expanduser().resolve()
    if not source_path.is_file():
        print(f"Source photo not found: {source_path}", file=sys.stderr)
        return 1

    # Check service health
    try:
        requests.get(f"{parsed.url}/health", timeout=5).raise_for_status()
    except requests.RequestException:
        print(f"Diffusion service not reachable at {parsed.url}", file=sys.stderr)
        print("Start with: ./run dev diffusion", file=sys.stderr)
        return 1

    out_dir, single_out = _resolve_output(parsed.out, parsed.count)
    out_dir.mkdir(parents=True, exist_ok=True)

    if parsed.seed is not None:
        base_seed = parsed.seed
    else:
        # Keep base_seed + count - 1 within the same 2**31 - 1 ceiling used for
        # the draw (presumably the service's seed range — confirm).
        base_seed = random.randint(0, max(0, 2**31 - parsed.count))
    seeds = [base_seed + i for i in range(parsed.count)]

    print(f"Repainting {source_path.name} × {parsed.count}")
    print(f" Prompt: {parsed.prompt[:80]}{'...' if len(parsed.prompt) > 80 else ''}")
    print(f" Seeds: {seeds[:5]}{'...' if len(seeds) > 5 else ''}")
    print()

    source_b64 = _encode_image(source_path)

    # Remember which variant number and seed each job belongs to. Later steps
    # must not rely on completion order or on every submission succeeding.
    submitted: list[tuple[int, str, int]] = []
    for n, seed in enumerate(seeds, start=1):
        try:
            job_id = _submit_job(
                parsed.url, source_b64, parsed.prompt, parsed.negative,
                parsed.steps, parsed.guidance, seed, parsed.rating,
            )
        except Exception as e:  # best-effort per variant: report and move on
            print(f" Submit failed (seed={seed}): {e}", file=sys.stderr)
            continue
        submitted.append((n, job_id, seed))
        print(f" → {job_id[:8]} seed={seed}")

    if not submitted:
        print("All submissions failed.", file=sys.stderr)
        return 1

    job_ids = [job_id for _, job_id, _ in submitted]
    print(f"\nPolling {len(job_ids)} job(s)...")
    try:
        results = _poll_jobs(parsed.url, job_ids)
    except (requests.RequestException, ValueError) as e:
        print(f"Polling failed: {e}", file=sys.stderr)
        return 1

    saved = 0
    for n, job_id, seed in submitted:
        result = results.get(job_id)
        if result is None:
            continue
        payload = _extract_image(result)
        if payload is None:
            # Failed jobs are already reported while polling; flag anything else.
            if not (isinstance(result, dict) and "error" in result):
                print(f" ! {job_id[:8]} finished without an image", file=sys.stderr)
            continue

        try:
            image_bytes = base64.b64decode(payload["output_base64"])
        except ValueError as e:  # binascii.Error subclasses ValueError
            print(f" ! {job_id[:8]} returned invalid image data: {e}", file=sys.stderr)
            continue

        if single_out is not None and n == 1:
            out_file = single_out
        else:
            # Plain concatenation: a stem containing "{" or "}" must not be
            # interpreted as a format template.
            out_file = out_dir / f"{source_path.stem}_repaint_{n}.png"

        out_file.write_bytes(image_bytes)
        w, h = payload.get("width", "?"), payload.get("height", "?")
        seed_used = payload.get("seed", seed)
        print(f" Saved {out_file.name} ({w}×{h}, seed={seed_used})")
        saved += 1

    print(f"\n{saved}/{len(job_ids)} images saved to {out_dir}")
    return 0 if saved > 0 else 1
|
||
|
||
|
||
def register_repaint_command(runner) -> None:
    """Hook the ``repaint`` subcommand into *runner*.

    *runner* must provide ``register_command(name, handler, text)``. The third
    argument is presumably the command's one-line help text; confirm against
    the runner implementation.
    """
    command_name = "repaint"
    summary = "Replace image background via SDXL inpainting"
    runner.register_command(command_name, repaint_command, summary)
|