# imajin/scripts/run/repaint_command.py

"""Repaint command — background replacement via SDXL inpainting.

Usage:
    ./run repaint --source photo.jpg --prompt "luxury hotel suite, city view"
    ./run repaint --source photo.jpg --prompt "hotel suite" --count 4 --out ./results/
    ./run repaint --source photo.jpg --prompt "..." --seed 42 --count 3
"""
import argparse
import base64
import io
import json
import sys
import time
from pathlib import Path
from typing import Optional
import requests
def _encode_image(path: Path) -> str:
    """Return the contents of the file at *path* as a base64 text string.

    The whole file is read into memory and encoded with the standard
    base64 alphabet; no ``data:`` URI prefix is added.
    """
    image_bytes = path.read_bytes()
    encoded = base64.standard_b64encode(image_bytes)
    return encoded.decode()
def _submit_job(
    url: str,
    source_b64: str,
    background_prompt: str,
    negative_prompt: Optional[str],
    steps: int,
    guidance_scale: float,
    seed: int,
    rating: str,
) -> str:
    """Queue one asynchronous background-repaint job and return its job id.

    The arguments are sent as a JSON body to
    ``{url}/generate/repaint-background/async`` with a 30 second timeout.

    Raises:
        requests.HTTPError: when the service answers with an error status.
        RuntimeError: when the reply lacks a truthy ``success`` or ``jobId``.
    """
    endpoint = f"{url}/generate/repaint-background/async"
    request_body = {
        "sourceImage": source_b64,
        "backgroundPrompt": background_prompt,
        "negativePrompt": negative_prompt,
        "steps": steps,
        "guidanceScale": guidance_scale,
        "seed": seed,
        "maturityRating": rating,
    }

    response = requests.post(endpoint, json=request_body, timeout=30)
    response.raise_for_status()
    reply = response.json()

    if reply.get("success") and reply.get("jobId"):
        return reply["jobId"]
    raise RuntimeError(f"Submit failed: {reply}")
def _poll_jobs(
    url: str,
    job_ids: list[str],
    interval: float = 3.0,
    timeout: Optional[float] = None,
    max_errors: int = 5,
) -> dict[str, dict]:
    """Poll the job endpoints until every job is resolved.

    Each round sleeps ``interval`` seconds and then queries
    ``{url}/jobs/{job_id}`` for every unresolved job. A job reported as
    ``completed`` has its payload fetched from ``{url}/jobs/{job_id}/result``;
    a job reported as ``failed`` is recorded as ``{"error": ...}``. Any other
    status leaves the job pending for the next round.

    Args:
        url: Base URL of the service.
        job_ids: Ids of the jobs to wait for.
        interval: Seconds to sleep before each polling round.
        timeout: Overall limit in seconds; ``None`` waits indefinitely.
            Jobs still pending at the deadline get an ``{"error": ...}`` entry.
        max_errors: Consecutive request/parse failures tolerated per job
            before that job is recorded as an error.

    Returns:
        ``{job_id: result_data}`` with one entry per resolved job. Entries are
        inserted in the order jobs resolve, not in submission order.
    """
    pending = set(job_ids)
    results: dict[str, dict] = {}
    error_streak: dict[str, int] = {}
    # Monotonic clock so wall-clock adjustments cannot shorten or extend the wait.
    deadline = None if timeout is None else time.monotonic() + timeout

    while pending:
        if deadline is not None and time.monotonic() >= deadline:
            for job_id in sorted(pending):
                results[job_id] = {"error": "timed out"}
                print(f"{job_id[:8]} timed out", file=sys.stderr)
            break

        time.sleep(interval)
        for job_id in list(pending):
            result_data = None
            try:
                resp = requests.get(f"{url}/jobs/{job_id}", timeout=10)
                resp.raise_for_status()
                data = resp.json()
                status = data.get("status")
                if status == "completed":
                    result_resp = requests.get(f"{url}/jobs/{job_id}/result", timeout=30)
                    result_resp.raise_for_status()
                    result_data = result_resp.json()
            except (requests.RequestException, ValueError) as exc:
                # One flaky request must not discard the results of the other
                # jobs; give up on this job only after repeated failures.
                error_streak[job_id] = error_streak.get(job_id, 0) + 1
                if error_streak[job_id] >= max_errors:
                    results[job_id] = {"error": f"polling failed: {exc}"}
                    pending.discard(job_id)
                    print(f"{job_id[:8]} failed: {exc}", file=sys.stderr)
                continue

            error_streak.pop(job_id, None)
            if status == "completed":
                results[job_id] = result_data
                pending.discard(job_id)
                print(f"{job_id[:8]} done")
            elif status == "failed":
                results[job_id] = {"error": data.get("error", "failed")}
                pending.discard(job_id)
                print(f"{job_id[:8]} failed: {data.get('error', '?')}", file=sys.stderr)

    return results
_REPAINT_EPILOG = """
Examples:
# Single repaint
./run repaint --source photo.jpg --prompt "luxury hotel suite, city skyline"
# Batch of 4 variants from the same source
./run repaint --source photo.jpg --prompt "hotel suite" --count 4 --out ./results/
# Deterministic batch (seed+0, seed+1, ...)
./run repaint --source photo.jpg --prompt "hotel suite" --seed 42 --count 3
# NSFW rating (removes SFW clothing-lock from prompt)
./run repaint --source photo.jpg --prompt "hotel suite" --rating nsfw
"""


def _build_repaint_parser() -> argparse.ArgumentParser:
    """Create the argument parser for ``./run repaint``."""
    parser = argparse.ArgumentParser(
        prog="./run repaint",
        description="Replace image background using SDXL inpainting (BiRefNet segmentation)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_REPAINT_EPILOG,
    )
    parser.add_argument("--source", "-s", required=True, type=Path, help="Source photo path")
    parser.add_argument("--prompt", "-p", required=True, help="Background scene description")
    parser.add_argument("--negative", "-n", default=None, help="Negative prompt")
    parser.add_argument("--count", "-c", type=int, default=1, help="Number of variants (default: 1)")
    parser.add_argument("--seed", type=int, default=None, help="Starting seed (increments per variant)")
    parser.add_argument("--steps", type=int, default=35, help="Inference steps (default: 35)")
    parser.add_argument("--guidance", type=float, default=7.5, help="CFG guidance scale (default: 7.5)")
    parser.add_argument("--rating", choices=["sfw", "nsfw", "explicit"], default="nsfw", help="Content rating (default: nsfw)")
    parser.add_argument("--out", "-o", type=Path, default=None, help="Output directory (or file if count=1)")
    parser.add_argument("--url", default="http://localhost:8002", help="Diffusion service URL")
    return parser


def repaint_command(args: list[str], workspace_root: Path) -> int:
    """Run the ``repaint`` command: submit repaint jobs, wait, save the images.

    One job is submitted per variant (seed ``base``, ``base + 1``, ...). After
    polling, each returned ``output_base64`` payload is decoded and written
    either to the single ``--out`` file (count 1 with an image suffix) or to
    ``<source stem>_repaint_<n>.png`` inside the output directory, where ``n``
    is the 1-based variant number, so a file name always maps to the same seed.

    Args:
        args: Command-line arguments that follow ``repaint``.
        workspace_root: Passed in by the command runner; not used here.

    Returns:
        0 when at least one image was saved, otherwise 1.

    Raises:
        SystemExit: raised by argparse for invalid arguments, including a
            ``--count`` below 1.
    """
    import random

    parser = _build_repaint_parser()
    parsed = parser.parse_args(args)
    if parsed.count < 1:
        # Without this, zero jobs are submitted and the user is told that
        # "all submissions failed", which is misleading.
        parser.error("--count must be at least 1")

    source_path = parsed.source.expanduser().resolve()
    # is_file() also rejects directories, which would fail later on read.
    if not source_path.is_file():
        print(f"Source photo not found: {source_path}", file=sys.stderr)
        return 1

    # A trailing slash on --url would otherwise produce "//" in request paths.
    base_url = parsed.url.rstrip("/")

    # Check service health
    try:
        requests.get(f"{base_url}/health", timeout=5).raise_for_status()
    except requests.RequestException:
        print(f"Diffusion service not reachable at {parsed.url}", file=sys.stderr)
        print("Start with: ./run dev diffusion", file=sys.stderr)
        return 1

    # Determine output paths
    out_path = (parsed.out or Path(".")).expanduser().resolve()
    if parsed.count == 1 and out_path.suffix in (".png", ".jpg", ".webp"):
        out_dir = out_path.parent
        single_out: Optional[Path] = out_path
    else:
        out_dir = out_path if out_path.suffix == "" else out_path.parent
        single_out = None
    out_dir.mkdir(parents=True, exist_ok=True)

    # Generate seeds
    if parsed.seed is not None:
        base_seed = parsed.seed
    else:
        # NOTE(review): the 2**31 - 1 bound presumably reflects a signed
        # 32-bit seed limit in the service; leave headroom so the incremented
        # seeds stay inside the same range. Confirm against the service.
        base_seed = random.randint(0, max(0, 2**31 - 1 - (parsed.count - 1)))
    seeds = [base_seed + i for i in range(parsed.count)]

    print(f"Repainting {source_path.name} × {parsed.count}")
    print(f" Prompt: {parsed.prompt[:80]}{'...' if len(parsed.prompt) > 80 else ''}")
    print(f" Seeds: {seeds[:5]}{'...' if len(seeds) > 5 else ''}")
    print()

    source_b64 = _encode_image(source_path)

    # Submit all jobs, remembering which variant/seed each job id belongs to.
    submitted: list[tuple[int, int, str]] = []
    for variant, seed in enumerate(seeds, start=1):
        try:
            job_id = _submit_job(
                base_url, source_b64, parsed.prompt, parsed.negative,
                parsed.steps, parsed.guidance, seed, parsed.rating,
            )
        except (requests.RequestException, RuntimeError, ValueError) as e:
            print(f" Submit failed (seed={seed}): {e}", file=sys.stderr)
            continue
        submitted.append((variant, seed, job_id))
        print(f"{job_id[:8]} seed={seed}")

    if not submitted:
        print("All submissions failed.", file=sys.stderr)
        return 1

    job_ids = [job_id for _, _, job_id in submitted]
    print(f"\nPolling {len(job_ids)} job(s)...")
    try:
        results = _poll_jobs(base_url, job_ids)
    except requests.RequestException as e:
        print(f"Polling failed: {e}", file=sys.stderr)
        return 1

    # Save results in submission order: the poller returns them in completion
    # order, which would otherwise scramble file numbering and seed reporting.
    saved = 0
    for variant, seed, job_id in submitted:
        result = results.get(job_id)
        if not isinstance(result, dict):
            continue
        # Navigate into nested result structure
        r = result.get("result", result)
        if not isinstance(r, dict):
            continue
        b64 = r.get("output_base64", "")
        if not b64:
            continue  # failed job or a reply without image data

        try:
            image_bytes = base64.b64decode(b64)
        except ValueError as e:
            print(f" Could not decode image for {job_id[:8]}: {e}", file=sys.stderr)
            continue

        if single_out is not None:
            out_file = single_out
        else:
            # Built with an f-string so braces in the file stem are harmless.
            out_file = out_dir / f"{source_path.stem}_repaint_{variant}.png"
        out_file.write_bytes(image_bytes)

        w, h = r.get("width", "?"), r.get("height", "?")
        seed_used = r.get("seed", seed)
        print(f" Saved {out_file.name} ({w}×{h}, seed={seed_used})")
        saved += 1

    print(f"\n{saved}/{len(job_ids)} images saved to {out_dir}")
    return 0 if saved > 0 else 1
def register_repaint_command(runner) -> None:
    """Register the ``repaint`` command and its one-line summary on *runner*."""
    command_name = "repaint"
    summary = "Replace image background via SDXL inpainting"
    runner.register_command(command_name, repaint_command, summary)