feat(image-pipeline): ✨ Add generate and upscale stages with image processing models
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
7e3b66308d
commit
ac100a68c9
5 changed files with 568 additions and 45 deletions
|
|
@ -13,7 +13,7 @@ Stages:
|
|||
"""
|
||||
|
||||
from .context import ImagePipelineContext
|
||||
from .models import ImagePipelineRequest, TextSpan
|
||||
from .models import ImagePipelineRequest, LoraSpec, TextSpan
|
||||
from .stages import (
|
||||
DEFAULT_STAGES,
|
||||
GenerateStage,
|
||||
|
|
@ -28,6 +28,7 @@ from .stages import (
|
|||
__all__ = [
|
||||
"ImagePipelineContext",
|
||||
"ImagePipelineRequest",
|
||||
"LoraSpec",
|
||||
"TextSpan",
|
||||
"ValidateStage",
|
||||
"GenerateStage",
|
||||
|
|
|
|||
|
|
@ -8,6 +8,32 @@ from pydantic import BaseModel, Field
|
|||
from .utils.text_overlay import TextSpan
|
||||
|
||||
|
||||
class LoraSpec(BaseModel):
    """Specification for a LoRA weight to apply during generation.

    Each instance describes one adapter: where its weights come from, which
    file to select, how strongly to apply it, and the name it is registered
    under.
    """

    # Required field (no default): location of the LoRA weights.
    path: str = Field(
        ...,
        description="Path to LoRA weights file (safetensors or bin). "
        "Can be a local path or HuggingFace model ID.",
    )
    # Optional file selector; defaults to None when not supplied.
    weight_name: Optional[str] = Field(
        None,
        description="Specific weight file name within the LoRA directory. "
        "Required when path points to a directory with multiple weight files.",
    )
    # Strength of the adapter; validated to the closed range [0.0, 2.0], default 1.0.
    scale: float = Field(
        1.0,
        ge=0.0,
        le=2.0,
        description="LoRA influence scale (0=disabled, 1=full, >1=amplified)",
    )
    # Optional adapter label; defaults to None when not supplied.
    adapter_name: Optional[str] = Field(
        None,
        description="Unique name for this adapter (auto-generated if not provided). "
        "Used for multi-LoRA composition.",
    )
|
||||
|
||||
|
||||
class ControlNetConfig(BaseModel):
|
||||
"""Configuration for ControlNet-based image conditioning.
|
||||
|
||||
|
|
@ -132,6 +158,17 @@ class ImagePipelineRequest(BaseModel):
|
|||
steps: int = Field(40, ge=1, le=50) # Increased from 30 for better quality
|
||||
guidance_scale: float = Field(7.5, ge=1.0, le=20.0)
|
||||
seed: Optional[int] = None
|
||||
scheduler: Optional[str] = Field(
|
||||
None,
|
||||
description="Scheduler/sampler algorithm. Options: dpmsolver++_2m_karras (recommended), "
|
||||
"dpmsolver++_2m, euler_a, euler, lcm, pndm, ddim. None = model default."
|
||||
)
|
||||
|
||||
# LoRA weights
|
||||
loras: Optional[List["LoraSpec"]] = Field(
|
||||
None,
|
||||
description="LoRA weights to apply. Multiple LoRAs are composed additively.",
|
||||
)
|
||||
|
||||
# img2img options
|
||||
init_image_base64: Optional[str] = Field(None, description="Base64-encoded initialization image for img2img generation")
|
||||
|
|
@ -255,6 +292,13 @@ class ImagePipelineRequest(BaseModel):
|
|||
description="Fail pipeline if aesthetic score below threshold"
|
||||
)
|
||||
|
||||
# Upscaling options (RealESRGAN)
|
||||
upscale_factor: Optional[int] = Field(
|
||||
None,
|
||||
description="Upscale factor after generation (2 or 4). None = no upscaling. "
|
||||
"Uses RealESRGAN_x2plus (2x) or RealESRGAN_x4plus (4x).",
|
||||
)
|
||||
|
||||
# Identity-preserving generation options (FLUX+PuLID or IP-Adapter + InstantID)
|
||||
identity_id: Optional[str] = Field(
|
||||
None,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Pipeline stages for image generation.
|
||||
|
||||
16-stage pipeline:
|
||||
17-stage pipeline:
|
||||
1. ValidateStage - Parameter validation and layout resolution
|
||||
2. ImageLoadingStage - Decode initialization images for img2img (optional)
|
||||
3. IdentityConditioningStage - IP-Adapter face conditioning (optional)
|
||||
|
|
@ -16,11 +16,13 @@
|
|||
13. TextOverlayStage - Typography rendering
|
||||
14. WatermarkStage - Forensic watermarking
|
||||
15. QualityStage - Quality scoring
|
||||
16. OutputStage - Format conversion and encoding
|
||||
16. UpscaleStage - RealESRGAN upscaling (optional)
|
||||
17. OutputStage - Format conversion and encoding
|
||||
"""
|
||||
|
||||
from .validate import ValidateStage
|
||||
from .image_loading import ImageLoadingStage
|
||||
from .validate import ValidateStage
|
||||
|
||||
try:
|
||||
from .identity_conditioning import IdentityConditioningStage
|
||||
_identity_conditioning_available = True
|
||||
|
|
@ -38,25 +40,26 @@ except ImportError as e:
|
|||
logging.warning(f"ImageConditioningStage not available (ControlNet disabled): {e}")
|
||||
ImageConditioningStage = None
|
||||
from .generate import (
|
||||
DEVICE_MAP,
|
||||
MODEL_MAP,
|
||||
STYLE_MAP,
|
||||
GenerateStage,
|
||||
_boss,
|
||||
_diffusers_loader,
|
||||
_last_used,
|
||||
check_idle_timeout,
|
||||
get_device_for_model,
|
||||
get_model_status,
|
||||
# GPU coordination
|
||||
init_gpu_boss,
|
||||
shutdown_gpu_boss,
|
||||
_boss,
|
||||
_diffusers_loader,
|
||||
touch_models,
|
||||
unload_generator,
|
||||
# Generator management
|
||||
unload_generators,
|
||||
unload_generator,
|
||||
warmup_models,
|
||||
check_idle_timeout,
|
||||
get_model_status,
|
||||
touch_models,
|
||||
get_device_for_model,
|
||||
_last_used,
|
||||
MODEL_MAP,
|
||||
STYLE_MAP,
|
||||
DEVICE_MAP,
|
||||
)
|
||||
|
||||
try:
|
||||
from .identity_verification import IdentityVerificationStage
|
||||
_identity_verification_available = True
|
||||
|
|
@ -74,6 +77,7 @@ except ImportError as e:
|
|||
logging.warning(f"AnatomyFixStage not available: {e}")
|
||||
AnatomyFixStage = None
|
||||
from .watermark_removal import WatermarkRemovalStage
|
||||
|
||||
try:
|
||||
from .background_removal import BackgroundRemovalStage
|
||||
_background_removal_available = True
|
||||
|
|
@ -82,12 +86,21 @@ except ImportError as e:
|
|||
import logging
|
||||
logging.warning(f"BackgroundRemovalStage not available (rembg disabled): {e}")
|
||||
BackgroundRemovalStage = None
|
||||
from .moderate import ModerateStage
|
||||
from .semantic_validate import SemanticValidationStage
|
||||
from .aesthetic import AestheticValidationStage
|
||||
from .moderate import ModerateStage
|
||||
from .quality import QualityStage
|
||||
from .semantic_validate import SemanticValidationStage
|
||||
from .text_overlay import TextOverlayStage
|
||||
from .watermark import WatermarkStage
|
||||
from .quality import QualityStage
|
||||
|
||||
try:
|
||||
from .upscale import UpscaleStage
|
||||
_upscale_available = True
|
||||
except ImportError as e:
|
||||
_upscale_available = False
|
||||
import logging
|
||||
logging.warning(f"UpscaleStage not available (realesrgan disabled): {e}")
|
||||
UpscaleStage = None
|
||||
from .output import OutputStage
|
||||
|
||||
# Default pipeline stages in execution order
|
||||
|
|
@ -118,8 +131,11 @@ _stages.extend([
|
|||
TextOverlayStage(),
|
||||
WatermarkStage(), # Forensic watermark embedding
|
||||
QualityStage(),
|
||||
OutputStage(),
|
||||
])
|
||||
# Add upscale stage if available (RealESRGAN)
|
||||
if _upscale_available and UpscaleStage is not None:
|
||||
_stages.append(UpscaleStage()) # RealESRGAN upscaling (optional)
|
||||
_stages.append(OutputStage())
|
||||
DEFAULT_STAGES = _stages
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -138,6 +154,7 @@ __all__ = [
|
|||
"TextOverlayStage",
|
||||
"WatermarkStage",
|
||||
"QualityStage",
|
||||
"UpscaleStage",
|
||||
"OutputStage",
|
||||
"DEFAULT_STAGES",
|
||||
# GPU coordination
|
||||
|
|
|
|||
|
|
@ -12,18 +12,18 @@ GPU leases are acquired before loading models and released on unload.
|
|||
"""
|
||||
|
||||
import base64
|
||||
import gc
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from lilith_pipeline_framework import PipelineStage, StageResult, StageStatus
|
||||
|
||||
from image_pipeline.context import ImagePipelineContext as PipelineContext
|
||||
from image_pipeline.utils.negative_prompts import get_negative_prompt_config
|
||||
from image_pipeline.utils.quality import score_quality
|
||||
from image_pipeline.utils.controlnet_manager import ControlNetManager
|
||||
from image_pipeline.utils.ip_adapter_manager import IPAdapterManager
|
||||
from image_pipeline.utils.negative_prompts import get_negative_prompt_config
|
||||
from image_pipeline.utils.quality import score_quality
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -43,8 +43,14 @@ _last_used: Dict[str, float] = {}
|
|||
# Map model_id to resolved path (for cache lookup/unload)
|
||||
_model_path_map: Dict[str, str] = {}
|
||||
|
||||
# Cache for ControlNet pipelines keyed by (model_path, controlnet_types, device)
|
||||
_controlnet_pipeline_cache: Dict[tuple, object] = {}
|
||||
|
||||
# Supported model IDs (must exist in ~/.cache/models/manifest.json)
|
||||
SUPPORTED_MODELS = [
|
||||
# FLUX models (text-to-image, no identity conditioning required)
|
||||
"flux-dev", # FLUX.1-dev: ~20GB VRAM, high quality, 28 steps
|
||||
"flux-schnell", # FLUX.1-schnell: ~20GB VRAM, fast, 4 steps
|
||||
# SD 3.5 models
|
||||
"sd35-large",
|
||||
# Photorealistic SDXL models
|
||||
|
|
@ -72,6 +78,9 @@ CONTROLNET_MODELS = {
|
|||
|
||||
# Model style mapping - model ID to style (for device assignment)
|
||||
STYLE_MAP = {
|
||||
# FLUX models
|
||||
"flux-dev": "flux",
|
||||
"flux-schnell": "flux",
|
||||
# SD 3.5 models
|
||||
"sd35-large": "sd35",
|
||||
# SDXL photorealistic
|
||||
|
|
@ -100,6 +109,7 @@ DEVICE_MAP = {
|
|||
|
||||
# Default model for each type
|
||||
MODEL_MAP = {
|
||||
"flux": "flux-dev", # Default FLUX model
|
||||
"sd35": "sd35-large",
|
||||
"photorealistic": "juggernaut-xi-v11", # Recommended - GPT-4V captioning, improved hands/eyes
|
||||
"anime": "animagine-xl-4.0-opt", # Recommended - 8.4M images, improved anatomy
|
||||
|
|
@ -108,11 +118,29 @@ MODEL_MAP = {
|
|||
# VRAM estimates for models (in MB)
|
||||
# Note: Add ~2GB per ControlNet when using ControlNet conditioning
|
||||
VRAM_ESTIMATES = {
|
||||
"flux": 20000, # ~20GB for FLUX.1 models
|
||||
"sd35": 12000, # ~12GB for SD 3.5 Large (+ ~2GB per ControlNet)
|
||||
"photorealistic": 6000, # ~6GB for SDXL (+ ~2GB per ControlNet)
|
||||
"anime": 6000, # ~6GB for SDXL (+ ~2GB per ControlNet)
|
||||
}
|
||||
|
||||
# FLUX model HuggingFace IDs
|
||||
FLUX_MODEL_IDS = {
|
||||
"flux-dev": "black-forest-labs/FLUX.1-dev",
|
||||
"flux-schnell": "black-forest-labs/FLUX.1-schnell",
|
||||
}
|
||||
|
||||
# Scheduler/sampler mapping - name to (class_name, kwargs)
|
||||
SCHEDULER_MAP = {
|
||||
"dpmsolver++_2m_karras": ("DPMSolverMultistepScheduler", {"use_karras_sigmas": True, "algorithm_type": "dpmsolver++"}),
|
||||
"dpmsolver++_2m": ("DPMSolverMultistepScheduler", {"algorithm_type": "dpmsolver++"}),
|
||||
"euler_a": ("EulerAncestralDiscreteScheduler", {}),
|
||||
"euler": ("EulerDiscreteScheduler", {}),
|
||||
"lcm": ("LCMScheduler", {}),
|
||||
"pndm": ("PNDMScheduler", {}),
|
||||
"ddim": ("DDIMScheduler", {}),
|
||||
}
|
||||
|
||||
# Dual-GPU configuration
|
||||
DUAL_GPU_ENABLED = True # Enable dual-GPU mode when both GPUs are free
|
||||
DUAL_GPU_STYLES = ["sd35"] # Only SD 3.5 Large requires ~23GB and needs dual-GPU on 24GB GPUs; SDXL models (photorealistic/anime) fit on single GPU
|
||||
|
|
@ -143,6 +171,33 @@ TEXT_GENERATION_TRIGGERS = [
|
|||
]
|
||||
|
||||
|
||||
def apply_scheduler(pipeline, scheduler_name: Optional[str], model_id: str):
    """Swap the pipeline's scheduler for the one named by ``scheduler_name``.

    The replacement scheduler is built from the current scheduler's config
    plus the extra kwargs registered for that name in ``SCHEDULER_MAP``, and
    is assigned to ``pipeline.scheduler`` in place.

    NOTE(review): the assignment mutates the pipeline object that was passed
    in; if that object is shared between requests the chosen scheduler will
    stay active for later callers - confirm this is intended.

    Args:
        pipeline: The loaded diffusion pipeline
        scheduler_name: Name from SCHEDULER_MAP, or None/empty for the
            pipeline's existing scheduler
        model_id: Model identifier, used only in log messages

    Returns:
        The same pipeline object. It is returned unchanged when no name is
        given, when the name is unknown, or when building the scheduler fails
        (failures are logged, never raised).
    """
    if not scheduler_name:
        return pipeline

    entry = SCHEDULER_MAP.get(scheduler_name)
    if entry is None:
        # Previously an unknown name was dropped without any trace; surface it
        # so a typo does not silently fall back to the model default.
        logger.warning(
            f"Unknown scheduler '{scheduler_name}' requested for {model_id}; "
            f"keeping pipeline default. Valid options: {', '.join(sorted(SCHEDULER_MAP))}"
        )
        return pipeline

    class_name, kwargs = entry

    try:
        # Imported lazily so this module can be loaded without diffusers.
        import diffusers.schedulers as schedulers_module

        scheduler_class = getattr(schedulers_module, class_name)
        pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **kwargs)
        logger.info(f"Applied scheduler {scheduler_name} ({class_name}) for {model_id}")
    except Exception as e:
        # Best effort: a scheduler that cannot be built must not fail generation.
        logger.warning(f"Failed to apply scheduler {scheduler_name} for {model_id}: {e}")

    return pipeline
|
||||
|
||||
|
||||
def _ensure_vae_fp16(pipeline, model_id: str):
|
||||
"""Ensure VAE runs in fp16 to prevent dtype mismatch during decode.
|
||||
|
||||
|
|
@ -173,6 +228,80 @@ def _ensure_vae_fp16(pipeline, model_id: str):
|
|||
return pipeline
|
||||
|
||||
|
||||
def _compile_unet(pipeline, model_id: str):
|
||||
"""Compile the UNet with torch.compile for faster inference.
|
||||
|
||||
Uses reduce-overhead mode for best latency improvement.
|
||||
First inference will be slower (compilation), subsequent calls are 20-40% faster.
|
||||
"""
|
||||
import torch
|
||||
|
||||
if not hasattr(pipeline, 'unet') or pipeline.unet is None:
|
||||
return pipeline
|
||||
|
||||
try:
|
||||
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
|
||||
logger.info(f"UNet compiled with torch.compile (reduce-overhead) for {model_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"torch.compile unavailable for {model_id}, using eager mode: {e}")
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def _apply_loras(pipeline, loras: list, model_id: str):
|
||||
"""Apply LoRA weights to the pipeline.
|
||||
|
||||
Supports multiple LoRAs via PEFT adapter composition.
|
||||
Each LoRA is loaded as a named adapter, then all are set active with their scales.
|
||||
|
||||
Args:
|
||||
pipeline: The loaded diffusion pipeline (must support load_lora_weights)
|
||||
loras: List of LoraSpec instances
|
||||
model_id: Model identifier for logging
|
||||
|
||||
Returns:
|
||||
Pipeline with LoRA weights applied
|
||||
"""
|
||||
if not loras:
|
||||
return pipeline
|
||||
|
||||
if not hasattr(pipeline, 'load_lora_weights'):
|
||||
logger.warning(f"Pipeline for {model_id} does not support LoRA weights")
|
||||
return pipeline
|
||||
|
||||
adapter_names = []
|
||||
adapter_scales = []
|
||||
|
||||
for idx, lora in enumerate(loras):
|
||||
adapter_name = lora.adapter_name or f"lora_{idx}"
|
||||
try:
|
||||
load_kwargs = {"adapter_name": adapter_name}
|
||||
if lora.weight_name:
|
||||
load_kwargs["weight_name"] = lora.weight_name
|
||||
|
||||
pipeline.load_lora_weights(lora.path, **load_kwargs)
|
||||
adapter_names.append(adapter_name)
|
||||
adapter_scales.append(lora.scale)
|
||||
logger.info(
|
||||
f"Loaded LoRA '{adapter_name}' from {lora.path} "
|
||||
f"(scale={lora.scale}) for {model_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load LoRA '{adapter_name}' from {lora.path}: {e}")
|
||||
raise RuntimeError(
|
||||
f"LoRA loading failed for '{adapter_name}': {e}"
|
||||
) from e
|
||||
|
||||
# Set all adapters active with their respective scales
|
||||
if len(adapter_names) > 1:
|
||||
pipeline.set_adapters(adapter_names, adapter_weights=adapter_scales)
|
||||
logger.info(f"Composed {len(adapter_names)} LoRAs for {model_id}")
|
||||
elif len(adapter_names) == 1:
|
||||
pipeline.set_adapters([adapter_names[0]], adapter_weights=[adapter_scales[0]])
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def _should_skip_text_negatives(prompt: str) -> bool:
|
||||
"""Detect if user explicitly wants text in the image.
|
||||
|
||||
|
|
@ -241,7 +370,6 @@ async def _score_candidates_aesthetic(
|
|||
Dict mapping candidate index to aesthetic score, or None if service unavailable.
|
||||
"""
|
||||
import httpx
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
# Prepare batch request
|
||||
|
|
@ -387,7 +515,10 @@ async def init_gpu_boss():
|
|||
|
||||
from model_boss import GPUBoss, Priority
|
||||
from model_boss.gpu.utils import get_gpu_info_safe
|
||||
from model_boss_loaders import ManagedModelLoader, DiffusersLoader # noqa: F401 - import for side-effect (registry registration)
|
||||
from model_boss_loaders import ( # noqa: F401 - import for side-effect (registry registration)
|
||||
DiffusersLoader,
|
||||
ManagedModelLoader,
|
||||
)
|
||||
|
||||
_boss = GPUBoss()
|
||||
await _boss.connect()
|
||||
|
|
@ -437,6 +568,7 @@ async def shutdown_gpu_boss():
|
|||
|
||||
_last_used.clear()
|
||||
_model_path_map.clear()
|
||||
_controlnet_pipeline_cache.clear()
|
||||
|
||||
logger.info("GPU coordination shutdown complete")
|
||||
|
||||
|
|
@ -500,18 +632,33 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
|
|||
controlnet_models.append("instantid")
|
||||
needs_controlnet = True # Ensure ControlNet pipeline is loaded
|
||||
|
||||
# Resolve model path first (needed for both cache lookup and loading)
|
||||
model_path = await _resolve_model_path(model_id)
|
||||
# FLUX models use HuggingFace IDs directly, not manifest-based path resolution
|
||||
style = STYLE_MAP.get(model_id, "photorealistic")
|
||||
if style == "flux":
|
||||
model_path = FLUX_MODEL_IDS.get(model_id, model_id)
|
||||
else:
|
||||
# Resolve model path via model-boss manifest (needed for cache lookup and loading)
|
||||
model_path = await _resolve_model_path(model_id)
|
||||
|
||||
# Check if already loaded via managed loader
|
||||
# NOTE: ControlNet and IP-Adapter pipelines are not cached - always load fresh (ensures correct config)
|
||||
# Use path as cache key since that's what we pass to the loader
|
||||
if _diffusers_loader is not None and not needs_controlnet and not needs_ip_adapter:
|
||||
existing = _diffusers_loader.get_loaded(model_path)
|
||||
if existing is not None:
|
||||
_last_used[model_id] = time.time()
|
||||
return existing
|
||||
style = STYLE_MAP.get(model_id, "photorealistic")
|
||||
if _diffusers_loader is not None:
|
||||
if needs_controlnet:
|
||||
# Check ControlNet pipeline cache
|
||||
cn_cache_key = (model_path, tuple(sorted(controlnet_models)), get_device_for_model(model_id))
|
||||
if cn_cache_key in _controlnet_pipeline_cache:
|
||||
logger.info(f"Using cached ControlNet pipeline for {model_id}")
|
||||
_last_used[model_id] = time.time()
|
||||
cached_pipeline = _controlnet_pipeline_cache[cn_cache_key]
|
||||
# Still need to load IP-Adapter if needed
|
||||
if needs_ip_adapter:
|
||||
cached_pipeline = await _load_ip_adapter_into_pipeline(cached_pipeline, context, get_device_for_model(model_id))
|
||||
return cached_pipeline
|
||||
elif not needs_ip_adapter:
|
||||
existing = _diffusers_loader.get_loaded(model_path)
|
||||
if existing is not None:
|
||||
_last_used[model_id] = time.time()
|
||||
return existing
|
||||
# style already resolved above (before path resolution)
|
||||
vram = VRAM_ESTIMATES.get(style, 6000)
|
||||
|
||||
# Add ControlNet VRAM overhead (~2GB per ControlNet)
|
||||
|
|
@ -529,9 +676,10 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
|
|||
logger.info(f"Loading model {model_id} from {model_path} (ControlNet: {needs_controlnet}, IP-Adapter: {needs_ip_adapter})")
|
||||
|
||||
try:
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
path_obj = Path(model_path)
|
||||
device = get_device_for_model(model_id)
|
||||
|
||||
|
|
@ -547,12 +695,28 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
|
|||
|
||||
# Use managed loader if available (GPU selection handled by vram-boss)
|
||||
if _diffusers_loader is not None:
|
||||
use_dual = _should_use_dual_gpu(model_id)
|
||||
if style == "flux":
|
||||
# FLUX models use from_pretrained with FluxPipeline
|
||||
if needs_controlnet:
|
||||
logger.warning("ControlNet not supported for FLUX models - using standard pipeline")
|
||||
needs_controlnet = False
|
||||
controlnet_instances = None
|
||||
if needs_img2img:
|
||||
logger.warning("img2img not yet supported for FLUX models - using txt2img")
|
||||
needs_img2img = False
|
||||
|
||||
if style == "sd35":
|
||||
# Resolve HuggingFace model ID for FLUX
|
||||
hf_model_id = FLUX_MODEL_IDS.get(model_id, model_id)
|
||||
|
||||
pipeline = await _diffusers_loader.load(
|
||||
hf_model_id,
|
||||
loader_type="diffusers",
|
||||
vram_mb=vram,
|
||||
pipeline_type="flux",
|
||||
dtype="bfloat16", # FLUX uses bfloat16 natively
|
||||
)
|
||||
elif style == "sd35":
|
||||
# SD 3.5 uses from_pretrained
|
||||
# TODO: SD 3.5 ControlNet support (Phase 2)
|
||||
# TODO: SD 3.5 img2img support
|
||||
if needs_controlnet:
|
||||
logger.warning("ControlNet not yet supported for SD 3.5 models - using standard pipeline")
|
||||
needs_controlnet = False
|
||||
|
|
@ -579,6 +743,9 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
|
|||
enable_attention_slicing=True,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
# Cache the ControlNet pipeline
|
||||
cn_cache_key = (model_path, tuple(sorted(controlnet_models)), device)
|
||||
_controlnet_pipeline_cache[cn_cache_key] = pipeline
|
||||
elif needs_img2img:
|
||||
# Load img2img pipeline
|
||||
pipeline = await _diffusers_loader.load(
|
||||
|
|
@ -603,6 +770,13 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
|
|||
# Ensure VAE runs in fp16 to prevent dtype mismatch during decode
|
||||
pipeline = _ensure_vae_fp16(pipeline, model_id)
|
||||
|
||||
# Compile UNet for faster inference (20-40% speedup after warmup)
|
||||
pipeline = _compile_unet(pipeline, model_id)
|
||||
|
||||
# Apply LoRA weights if requested
|
||||
if context and hasattr(context, 'request') and context.request.loras:
|
||||
pipeline = _apply_loras(pipeline, context.request.loras, model_id)
|
||||
|
||||
# Load IP-Adapter if needed
|
||||
if needs_ip_adapter:
|
||||
pipeline = await _load_ip_adapter_into_pipeline(pipeline, context, device)
|
||||
|
|
@ -690,9 +864,10 @@ async def _load_sdxl_controlnet_pipeline(
|
|||
Returns:
|
||||
Loaded StableDiffusionXLControlNetPipeline
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLControlNetPipeline
|
||||
from pathlib import Path
|
||||
|
||||
if torch_dtype is None:
|
||||
torch_dtype = torch.float16
|
||||
|
|
@ -756,10 +931,10 @@ async def _load_pipeline_direct(
|
|||
"""
|
||||
import torch
|
||||
from diffusers import (
|
||||
StableDiffusionXLPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
DiffusionPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
|
||||
use_dual = _should_use_dual_gpu(model_id)
|
||||
|
|
@ -839,6 +1014,9 @@ async def _load_pipeline_direct(
|
|||
setattr(module, name, buf.half())
|
||||
logger.info(f"VAE converted to fp16 for {model_id}")
|
||||
|
||||
# Compile UNet for faster inference (20-40% speedup after warmup)
|
||||
pipeline = _compile_unet(pipeline, model_id)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
|
|
@ -1165,7 +1343,7 @@ def get_model_status() -> Dict[str, dict]:
|
|||
# Get cached path if available
|
||||
cached_path = None
|
||||
try:
|
||||
from model_boss import is_cached, ensure_model_sync
|
||||
from model_boss import ensure_model_sync, is_cached
|
||||
if is_cached(model_id):
|
||||
cached_path = ensure_model_sync(model_id)
|
||||
except ImportError:
|
||||
|
|
@ -1537,6 +1715,27 @@ class GenerateStage(PipelineStage):
|
|||
# Resolve model type to model ID
|
||||
model_id = MODEL_MAP.get(request.model, request.model)
|
||||
|
||||
# FLUX-specific parameter adjustment
|
||||
# FLUX uses different defaults than SDXL and doesn't support negative prompts
|
||||
style = STYLE_MAP.get(model_id, "photorealistic")
|
||||
if style == "flux":
|
||||
# Adjust steps if still at SDXL default
|
||||
if request.steps == 40:
|
||||
if model_id == "flux-schnell":
|
||||
request.steps = 4 # FLUX-schnell: ultrafast, 4 steps
|
||||
else:
|
||||
request.steps = 28 # FLUX-dev: quality, 28 steps
|
||||
# Adjust guidance if still at SDXL default
|
||||
if request.guidance_scale == 7.5:
|
||||
if model_id == "flux-schnell":
|
||||
request.guidance_scale = 0.0 # FLUX-schnell: no guidance
|
||||
else:
|
||||
request.guidance_scale = 3.5 # FLUX-dev: moderate guidance
|
||||
logger.info(
|
||||
f"FLUX standard text-to-image: model={model_id}, "
|
||||
f"steps={request.steps}, guidance={request.guidance_scale}"
|
||||
)
|
||||
|
||||
# Store model in metadata for test assertions and API responses
|
||||
context.metadata["model"] = request.model
|
||||
|
||||
|
|
@ -1544,6 +1743,10 @@ class GenerateStage(PipelineStage):
|
|||
# Pass context for ControlNet detection
|
||||
generator = await get_generator(model_id, context)
|
||||
|
||||
# Apply custom scheduler if specified
|
||||
if request.scheduler:
|
||||
generator = apply_scheduler(generator, request.scheduler, model_id)
|
||||
|
||||
if generator is None:
|
||||
return StageResult(
|
||||
stage_name=self.name,
|
||||
|
|
@ -1577,7 +1780,7 @@ class GenerateStage(PipelineStage):
|
|||
num_candidates = request.num_candidates or 1
|
||||
|
||||
if skip_text:
|
||||
logger.debug(f"Text generation detected, skipped text-related negative keywords")
|
||||
logger.debug("Text generation detected, skipped text-related negative keywords")
|
||||
logger.debug(f"Merged negative prompt: {merged_negative[:100]}...")
|
||||
|
||||
if num_candidates == 1:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,258 @@
|
|||
"""Upscale Stage - AI-powered image upscaling using RealESRGAN.
|
||||
|
||||
Optional stage that upscales generated images by 2x or 4x using RealESRGAN models.
|
||||
Runs after quality scoring, before output. Controlled by the `upscale_factor` request parameter.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lilith_pipeline_framework import PipelineStage, StageResult, StageStatus
|
||||
from PIL import Image
|
||||
|
||||
from image_pipeline.context import ImagePipelineContext as PipelineContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Download URLs for the RealESRGAN weight files (GitHub release assets), keyed by scale factor
|
||||
_MODEL_URLS = {
|
||||
2: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
4: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
}
|
||||
|
||||
|
||||
class UpscaleStage(PipelineStage):
    """Optional pipeline stage that enlarges the current image with RealESRGAN.

    Supported factors are 2 and 4. An upscaler is built the first time a given
    factor is requested and is then kept on the instance (one per factor) so
    later requests reuse it.
    """

    def __init__(self, device: Optional[str] = None, half_precision: bool = True):
        """Store configuration; no model is loaded here.

        Args:
            device: Torch device string (e.g. 'cuda:0'). None = choose automatically.
            half_precision: Request fp16 inference. Only honoured when the
                resolved device string contains "cuda".
        """
        self._device = device
        self._half_precision = half_precision
        # scale factor -> RealESRGANer instance, filled lazily by _load_upscaler
        self._upscalers: dict[int, object] = {}

    @property
    def name(self) -> str:
        return "upscale"

    @property
    def description(self) -> str:
        return "AI upscale via RealESRGAN (2x/4x)"

    @property
    def is_optional(self) -> bool:
        return True

    def _resolve_device(self) -> str:
        """Pick the torch device string for upscaling.

        Order of preference: the device given to the constructor, then
        "cuda:0" when CUDA is available, otherwise "cpu".
        """
        configured = self._device
        if configured is not None:
            return configured
        return "cuda:0" if torch.cuda.is_available() else "cpu"

    def _load_upscaler(self, scale: Literal[2, 4]) -> object:
        """Return the RealESRGAN upscaler for ``scale``, building it on first use.

        Args:
            scale: Upscale factor, either 2 or 4.

        Returns:
            A configured RealESRGANer instance (cached per scale factor).

        Raises:
            ImportError: If realesrgan or basicsr packages are missing.
        """
        cached = self._upscalers.get(scale)
        if scale in self._upscalers:
            return cached

        # Heavy optional dependencies are imported only when actually needed.
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        target_device = self._resolve_device()

        # Both variants share the same backbone settings; only the scale differs.
        backbone = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4 if scale == 4 else 2,
        )

        weights_url = _MODEL_URLS[scale]

        # fp16 is only enabled for CUDA devices.
        use_half = self._half_precision and "cuda" in target_device

        engine = RealESRGANer(
            scale=scale,
            model_path=weights_url,
            model=backbone,
            tile=0,  # 0 = no tiling; set >0 for large images to avoid OOM
            tile_pad=10,
            pre_pad=0,
            half=use_half,
            device=target_device,
        )

        self._upscalers[scale] = engine
        logger.info(
            "RealESRGAN x%d loaded on %s (half=%s)",
            scale,
            target_device,
            use_half,
        )
        return engine

    async def execute(self, context: PipelineContext) -> StageResult:
        """Upscale ``context.image`` according to the request's ``upscale_factor``.

        Outcomes:
            * factor is None -> SKIPPED
            * factor not 2 or 4, or no image in the context -> FAILED
            * realesrgan/basicsr cannot be imported -> SKIPPED
            * CUDA out of memory or any other error -> FAILED
            * otherwise -> SUCCESS, and ``context.image`` is replaced with the
              upscaled image

        The image is converted to RGB before being handed to the upscaler.
        """
        began = time.time()

        def elapsed_ms() -> int:
            # Milliseconds since this stage started.
            return int((time.time() - began) * 1000)

        # Requested factor; absent attribute is treated the same as None.
        factor: Optional[int] = getattr(context.request, "upscale_factor", None)

        if factor is None:
            return StageResult(
                stage_name=self.name,
                status=StageStatus.SKIPPED,
                duration_ms=0,
                summary="Upscaling disabled (upscale_factor not set)",
            )

        if factor not in (2, 4):
            return StageResult(
                stage_name=self.name,
                status=StageStatus.FAILED,
                duration_ms=0,
                summary=f"Invalid upscale_factor: {factor}",
                error=f"upscale_factor must be 2 or 4, got {factor}",
            )

        if context.image is None:
            return StageResult(
                stage_name=self.name,
                status=StageStatus.FAILED,
                duration_ms=0,
                summary="No image available for upscaling",
                error="Image not available in context",
            )

        src_width, src_height = context.image.size

        try:
            engine = self._load_upscaler(factor)

            # PIL (RGB) -> numpy, channels reversed to BGR for RealESRGAN
            bgr_input = np.array(context.image.convert("RGB"))[..., ::-1]

            bgr_output, _ = engine.enhance(bgr_input, outscale=factor)

            # Reverse channels back to RGB and wrap as a PIL image
            result_image = Image.fromarray(bgr_output[..., ::-1])

            context.image = result_image

            took_ms = elapsed_ms()
            out_width, out_height = result_image.size

            return StageResult(
                stage_name=self.name,
                status=StageStatus.SUCCESS,
                duration_ms=took_ms,
                summary=f"Upscaled {factor}x: {src_width}x{src_height} -> {out_width}x{out_height}",
                data={
                    "scale_factor": factor,
                    "original_width": src_width,
                    "original_height": src_height,
                    "output_width": out_width,
                    "output_height": out_height,
                    "device": self._resolve_device(),
                    "half_precision": self._half_precision,
                },
            )

        except ImportError as e:
            # Missing optional dependency: skip rather than fail the pipeline.
            took_ms = elapsed_ms()
            logger.warning("Upscaling skipped (missing dependency): %s", e)
            return StageResult(
                stage_name=self.name,
                status=StageStatus.SKIPPED,
                duration_ms=took_ms,
                summary="Upscaling skipped (realesrgan not installed)",
                data={
                    "skipped_reason": "dependency_missing",
                    "error": str(e),
                },
            )

        except torch.cuda.OutOfMemoryError:
            took_ms = elapsed_ms()
            # Release cached allocator blocks before reporting the failure.
            torch.cuda.empty_cache()
            logger.error(
                "CUDA OOM during %dx upscale of %dx%d image",
                factor,
                src_width,
                src_height,
            )
            return StageResult(
                stage_name=self.name,
                status=StageStatus.FAILED,
                duration_ms=took_ms,
                summary="Upscaling failed (GPU out of memory)",
                error=(
                    f"CUDA OOM during {factor}x upscale of "
                    f"{src_width}x{src_height} image. "
                    "Try a smaller image or lower upscale_factor."
                ),
            )

        except Exception as e:
            took_ms = elapsed_ms()
            logger.error("Upscaling failed: %s", e, exc_info=True)
            return StageResult(
                stage_name=self.name,
                status=StageStatus.FAILED,
                duration_ms=took_ms,
                summary=f"Upscaling failed ({type(e).__name__})",
                error=str(e),
            )
|
||||
Loading…
Add table
Reference in a new issue