199 lines
7.3 KiB
Python
199 lines
7.3 KiB
Python
|
|
"""Repaint command — background replacement via SDXL inpainting.
|
|||
|
|
|
|||
|
|
Usage:
|
|||
|
|
./run repaint --source photo.jpg --prompt "luxury hotel suite, city view"
|
|||
|
|
./run repaint --source photo.jpg --prompt "hotel suite" --count 4 --out ./results/
|
|||
|
|
./run repaint --source photo.jpg --prompt "..." --seed 42 --count 3
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import base64
|
|||
|
|
import io
|
|||
|
|
import json
|
|||
|
|
import sys
|
|||
|
|
import time
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
import requests
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _encode_image(path: Path) -> str:
    """Return the contents of the file at *path* as base64 text.

    The whole file is read into memory and encoded as-is; the returned
    string is plain base64 with no data-URI prefix.
    """
    raw_bytes = path.read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _submit_job(
    url: str,
    source_b64: str,
    background_prompt: str,
    negative_prompt: Optional[str],
    steps: int,
    guidance_scale: float,
    seed: int,
    rating: str,
) -> str:
    """Submit one asynchronous repaint job and return its job id.

    Sends a JSON POST to ``{url}/generate/repaint-background/async`` with a
    30 second request timeout.

    Raises:
        requests.HTTPError: if the service answers with an error status
            (via ``raise_for_status``).
        RuntimeError: if the JSON reply has a falsy ``success`` or ``jobId``.
    """
    endpoint = f"{url}/generate/repaint-background/async"
    request_body = {
        "sourceImage": source_b64,
        "backgroundPrompt": background_prompt,
        "negativePrompt": negative_prompt,
        "steps": steps,
        "guidanceScale": guidance_scale,
        "seed": seed,
        "maturityRating": rating,
    }

    response = requests.post(endpoint, json=request_body, timeout=30)
    response.raise_for_status()
    reply = response.json()

    # Both fields must be present and truthy for the submission to count.
    if reply.get("success") and reply.get("jobId"):
        return reply["jobId"]
    raise RuntimeError(f"Submit failed: {reply}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _poll_jobs(
    url: str,
    job_ids: list[str],
    interval: float = 3.0,
    timeout: Optional[float] = None,
) -> dict[str, dict]:
    """Poll until every job is terminal and return ``{job_id: result_data}``.

    Each round sleeps ``interval`` seconds, then asks ``{url}/jobs/{job_id}``
    for the status of every job that is still pending:

    * ``"completed"`` -> the JSON body of ``{url}/jobs/{job_id}/result`` is
      stored as the job's result.
    * ``"failed"`` -> ``{"error": <service error or "failed">}`` is stored.
    * any other status -> the job stays pending and is polled again.

    Args:
        url: Base URL of the service.
        job_ids: Job ids to wait for (duplicates are polled once).
        interval: Seconds to sleep before each polling round.
        timeout: Optional overall limit in seconds. ``None`` (the default)
            waits indefinitely. Once the limit is reached, every job that is
            still pending is recorded as ``{"error": "timed out"}`` so a job
            that never reaches a terminal status cannot hang the caller.

    Raises:
        requests.HTTPError: if a status or result request returns an error
            status (via ``raise_for_status``).
    """
    pending = set(job_ids)
    results: dict[str, dict] = {}
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = None if timeout is None else time.monotonic() + timeout

    while pending:
        if deadline is not None and time.monotonic() >= deadline:
            # Walk job_ids (not the set) so timed-out entries are recorded in
            # submission order.
            for job_id in job_ids:
                if job_id in pending:
                    pending.discard(job_id)
                    results[job_id] = {"error": "timed out"}
                    print(f" ✗ {job_id[:8]} timed out", file=sys.stderr)
            break

        time.sleep(interval)
        # Iterate over a snapshot because finished jobs are removed in-loop.
        for job_id in list(pending):
            resp = requests.get(f"{url}/jobs/{job_id}", timeout=10)
            resp.raise_for_status()
            data = resp.json()
            status = data.get("status")
            if status == "completed":
                result_resp = requests.get(f"{url}/jobs/{job_id}/result", timeout=30)
                result_resp.raise_for_status()
                results[job_id] = result_resp.json()
                pending.discard(job_id)
                print(f" ✓ {job_id[:8]} done")
            elif status == "failed":
                results[job_id] = {"error": data.get("error", "failed")}
                pending.discard(job_id)
                print(f" ✗ {job_id[:8]} failed: {data.get('error', '?')}", file=sys.stderr)

    return results
|
|||
|
|
|
|||
|
|
|
|||
|
|
_REPAINT_EPILOG = """
Examples:
# Single repaint
./run repaint --source photo.jpg --prompt "luxury hotel suite, city skyline"

# Batch of 4 variants from the same source
./run repaint --source photo.jpg --prompt "hotel suite" --count 4 --out ./results/

# Deterministic batch (seed+0, seed+1, ...)
./run repaint --source photo.jpg --prompt "hotel suite" --seed 42 --count 3

# NSFW rating (removes SFW clothing-lock from prompt)
./run repaint --source photo.jpg --prompt "hotel suite" --rating nsfw
"""

# With --count 1, an --out value ending in one of these is treated as a file.
_OUTPUT_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg", ".webp")


def _build_repaint_parser() -> argparse.ArgumentParser:
    """Build the argument parser for ``./run repaint``."""
    parser = argparse.ArgumentParser(
        prog="./run repaint",
        description="Replace image background using SDXL inpainting (BiRefNet segmentation)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_REPAINT_EPILOG,
    )
    parser.add_argument("--source", "-s", required=True, type=Path, help="Source photo path")
    parser.add_argument("--prompt", "-p", required=True, help="Background scene description")
    parser.add_argument("--negative", "-n", default=None, help="Negative prompt")
    parser.add_argument("--count", "-c", type=int, default=1, help="Number of variants (default: 1)")
    parser.add_argument("--seed", type=int, default=None, help="Starting seed (increments per variant)")
    parser.add_argument("--steps", type=int, default=35, help="Inference steps (default: 35)")
    parser.add_argument("--guidance", type=float, default=7.5, help="CFG guidance scale (default: 7.5)")
    parser.add_argument("--rating", choices=["sfw", "nsfw", "explicit"], default="nsfw", help="Content rating (default: nsfw)")
    parser.add_argument("--out", "-o", type=Path, default=None, help="Output directory (or file if count=1)")
    parser.add_argument("--url", default="http://localhost:8002", help="Diffusion service URL")
    return parser


def _resolve_output(out: Optional[Path], count: int) -> tuple[Path, Optional[Path]]:
    """Translate ``--out`` into ``(output_directory, explicit_file_or_None)``.

    An explicit file is honoured only for a single variant whose path ends in
    a known image suffix (compared case-insensitively). Otherwise a path
    without a suffix is taken as the directory itself and a path with a
    suffix contributes its parent directory.
    """
    out_path = (out or Path(".")).expanduser().resolve()
    if count == 1 and out_path.suffix.lower() in _OUTPUT_IMAGE_SUFFIXES:
        return out_path.parent, out_path
    out_dir = out_path if out_path.suffix == "" else out_path.parent
    return out_dir, None


def repaint_command(args: list[str], workspace_root: Path) -> int:
    """Run the ``repaint`` CLI command and return a process exit code.

    Submits ``--count`` repaint jobs for one source image (seeds ``base``,
    ``base + 1``, ...), waits for them via ``_poll_jobs`` and writes every
    returned image to disk. In batch mode files are named
    ``<source stem>_repaint_<n>.png`` where ``n`` is the 1-based position of
    the variant's seed, so a file name always maps to the same seed no matter
    which job finishes first or whether another submission failed.

    Args:
        args: Command-line arguments following the ``repaint`` sub-command.
        workspace_root: Accepted for the command-runner interface; not used.

    Returns:
        ``0`` if at least one image was saved, otherwise ``1``.
        ``argparse`` itself exits the process on malformed arguments.
    """
    parsed = _build_repaint_parser().parse_args(args)

    # Without this check a zero or negative count submits nothing and would
    # be reported as "All submissions failed.".
    if parsed.count < 1:
        print("--count must be at least 1", file=sys.stderr)
        return 1

    source_path = parsed.source.expanduser().resolve()
    # is_file() also rejects directories, which cannot be read as an image.
    if not source_path.is_file():
        print(f"Source photo not found: {source_path}", file=sys.stderr)
        return 1

    # Check service health
    try:
        requests.get(f"{parsed.url}/health", timeout=5).raise_for_status()
    except requests.RequestException:
        print(f"Diffusion service not reachable at {parsed.url}", file=sys.stderr)
        print("Start with: ./run dev diffusion", file=sys.stderr)
        return 1

    # Determine output paths
    out_dir, single_out = _resolve_output(parsed.out, parsed.count)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Generate seeds
    import random
    base_seed = parsed.seed if parsed.seed is not None else random.randint(0, 2**31 - 1)
    seeds = [base_seed + i for i in range(parsed.count)]

    print(f"Repainting {source_path.name} × {parsed.count}")
    print(f" Prompt: {parsed.prompt[:80]}{'...' if len(parsed.prompt) > 80 else ''}")
    print(f" Seeds: {seeds[:5]}{'...' if len(seeds) > 5 else ''}")
    print()

    source_b64 = _encode_image(source_path)

    # Submit all jobs, remembering which variant number and seed each job id
    # belongs to so results can be matched up after polling.
    submitted: list[tuple[int, int, str]] = []
    for variant_no, seed in enumerate(seeds, start=1):
        try:
            job_id = _submit_job(
                parsed.url, source_b64, parsed.prompt, parsed.negative,
                parsed.steps, parsed.guidance, seed, parsed.rating,
            )
        except Exception as e:
            # Best effort: one failed submission must not abort the batch.
            print(f" Submit failed (seed={seed}): {e}", file=sys.stderr)
            continue
        submitted.append((variant_no, seed, job_id))
        print(f" → {job_id[:8]} seed={seed}")

    if not submitted:
        print("All submissions failed.", file=sys.stderr)
        return 1

    job_ids = [job_id for _, _, job_id in submitted]
    print(f"\nPolling {len(job_ids)} job(s)...")
    try:
        results = _poll_jobs(parsed.url, job_ids)
    except (requests.RequestException, ValueError) as e:
        # Network/HTTP errors or an undecodable JSON body while polling.
        print(f"Polling failed: {e}", file=sys.stderr)
        return 1

    # Save results in submission order.
    saved = 0
    for variant_no, seed, job_id in submitted:
        result = results.get(job_id)
        if not isinstance(result, dict):
            result = {}

        # The image may sit at the top level or under a nested "result" key.
        payload = result.get("result", result)
        b64 = payload.get("output_base64") if isinstance(payload, dict) else None
        if not b64:
            if "error" not in result:
                # Failed jobs were already reported during polling.
                print(f" No image data returned for {job_id[:8]}", file=sys.stderr)
            continue

        try:
            image_bytes = base64.b64decode(b64)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            print(f" Invalid image data for {job_id[:8]}: {e}", file=sys.stderr)
            continue

        if single_out is not None:
            out_file = single_out
        else:
            # Built with an f-string rather than str.format so that braces in
            # the source file name cannot break formatting.
            out_file = out_dir / f"{source_path.stem}_repaint_{variant_no}.png"

        out_file.write_bytes(image_bytes)
        w, h = payload.get("width", "?"), payload.get("height", "?")
        seed_used = payload.get("seed", seed)
        print(f" Saved {out_file.name} ({w}×{h}, seed={seed_used})")
        saved += 1

    print(f"\n{saved}/{len(job_ids)} images saved to {out_dir}")
    return 0 if saved > 0 else 1
|
|||
|
|
|
|||
|
|
|
|||
|
|
def register_repaint_command(runner) -> None:
    """Register ``repaint_command`` on *runner* under the name ``repaint``."""
    command_name = "repaint"
    summary = "Replace image background via SDXL inpainting"
    runner.register_command(command_name, repaint_command, summary)
|