diff --git a/orchestrators/imajin-pipeline/src/image_pipeline/__init__.py b/orchestrators/imajin-pipeline/src/image_pipeline/__init__.py index 8500a53f..07ba64c4 100644 --- a/orchestrators/imajin-pipeline/src/image_pipeline/__init__.py +++ b/orchestrators/imajin-pipeline/src/image_pipeline/__init__.py @@ -13,7 +13,7 @@ Stages: """ from .context import ImagePipelineContext -from .models import ImagePipelineRequest, TextSpan +from .models import ImagePipelineRequest, LoraSpec, TextSpan from .stages import ( DEFAULT_STAGES, GenerateStage, @@ -28,6 +28,7 @@ from .stages import ( __all__ = [ "ImagePipelineContext", "ImagePipelineRequest", + "LoraSpec", "TextSpan", "ValidateStage", "GenerateStage", diff --git a/orchestrators/imajin-pipeline/src/image_pipeline/models.py b/orchestrators/imajin-pipeline/src/image_pipeline/models.py index 3fcb955a..fafd6f1b 100644 --- a/orchestrators/imajin-pipeline/src/image_pipeline/models.py +++ b/orchestrators/imajin-pipeline/src/image_pipeline/models.py @@ -8,6 +8,32 @@ from pydantic import BaseModel, Field from .utils.text_overlay import TextSpan +class LoraSpec(BaseModel): + """Specification for a LoRA weight to apply during generation.""" + + path: str = Field( + ..., + description="Path to LoRA weights file (safetensors or bin). " + "Can be a local path or HuggingFace model ID.", + ) + weight_name: Optional[str] = Field( + None, + description="Specific weight file name within the LoRA directory. " + "Required when path points to a directory with multiple weight files.", + ) + scale: float = Field( + 1.0, + ge=0.0, + le=2.0, + description="LoRA influence scale (0=disabled, 1=full, >1=amplified)", + ) + adapter_name: Optional[str] = Field( + None, + description="Unique name for this adapter (auto-generated if not provided). " + "Used for multi-LoRA composition.", + ) + + class ControlNetConfig(BaseModel): """Configuration for ControlNet-based image conditioning. @@ -132,6 +158,17 @@ class ImagePipelineRequest(BaseModel): steps: int = Field(40, ge=1, le=50) # Increased from 30 for better quality guidance_scale: float = Field(7.5, ge=1.0, le=20.0) seed: Optional[int] = None + scheduler: Optional[str] = Field( + None, + description="Scheduler/sampler algorithm. Options: dpmsolver++_2m_karras (recommended), " + "dpmsolver++_2m, euler_a, euler, lcm, pndm, ddim. None = model default." + ) + + # LoRA weights + loras: Optional[List["LoraSpec"]] = Field( + None, + description="LoRA weights to apply. Multiple LoRAs are composed additively.", + ) # img2img options init_image_base64: Optional[str] = Field(None, description="Base64-encoded initialization image for img2img generation") @@ -255,6 +292,13 @@ class ImagePipelineRequest(BaseModel): description="Fail pipeline if aesthetic score below threshold" ) + # Upscaling options (RealESRGAN) + upscale_factor: Optional[int] = Field( + None, + description="Upscale factor after generation (2 or 4). None = no upscaling. " + "Uses RealESRGAN_x2plus (2x) or RealESRGAN_x4plus (4x).", + ) + # Identity-preserving generation options (FLUX+PuLID or IP-Adapter + InstantID) identity_id: Optional[str] = Field( None, diff --git a/orchestrators/imajin-pipeline/src/image_pipeline/stages/__init__.py b/orchestrators/imajin-pipeline/src/image_pipeline/stages/__init__.py index db0c8ad6..c97cba0e 100644 --- a/orchestrators/imajin-pipeline/src/image_pipeline/stages/__init__.py +++ b/orchestrators/imajin-pipeline/src/image_pipeline/stages/__init__.py @@ -1,6 +1,6 @@ """Pipeline stages for image generation. -16-stage pipeline: +17-stage pipeline: 1. ValidateStage - Parameter validation and layout resolution 2. ImageLoadingStage - Decode initialization images for img2img (optional) 3. IdentityConditioningStage - IP-Adapter face conditioning (optional) @@ -16,11 +16,13 @@ 13. TextOverlayStage - Typography rendering 14. WatermarkStage - Forensic watermarking 15. QualityStage - Quality scoring -16. OutputStage - Format conversion and encoding +16. UpscaleStage - RealESRGAN upscaling (optional) +17. OutputStage - Format conversion and encoding """ -from .validate import ValidateStage from .image_loading import ImageLoadingStage +from .validate import ValidateStage + try: from .identity_conditioning import IdentityConditioningStage _identity_conditioning_available = True @@ -38,25 +40,26 @@ except ImportError as e: logging.warning(f"ImageConditioningStage not available (ControlNet disabled): {e}") ImageConditioningStage = None from .generate import ( + DEVICE_MAP, + MODEL_MAP, + STYLE_MAP, GenerateStage, + _boss, + _diffusers_loader, + _last_used, + check_idle_timeout, + get_device_for_model, + get_model_status, # GPU coordination init_gpu_boss, shutdown_gpu_boss, - _boss, - _diffusers_loader, + touch_models, + unload_generator, # Generator management unload_generators, - unload_generator, warmup_models, - check_idle_timeout, - get_model_status, - touch_models, - get_device_for_model, - _last_used, - MODEL_MAP, - STYLE_MAP, - DEVICE_MAP, ) + try: from .identity_verification import IdentityVerificationStage _identity_verification_available = True @@ -74,6 +77,7 @@ except ImportError as e: logging.warning(f"AnatomyFixStage not available: {e}") AnatomyFixStage = None from .watermark_removal import WatermarkRemovalStage + try: from .background_removal import BackgroundRemovalStage _background_removal_available = True @@ -82,12 +86,21 @@ except ImportError as e: import logging logging.warning(f"BackgroundRemovalStage not available (rembg disabled): {e}") BackgroundRemovalStage = None -from .moderate import ModerateStage -from .semantic_validate import SemanticValidationStage from .aesthetic import AestheticValidationStage +from .moderate import ModerateStage +from .quality import QualityStage +from .semantic_validate import SemanticValidationStage from .text_overlay import TextOverlayStage from .watermark import WatermarkStage -from .quality import QualityStage + +try: + from .upscale import UpscaleStage + _upscale_available = True +except ImportError as e: + _upscale_available = False + import logging + logging.warning(f"UpscaleStage not available (realesrgan disabled): {e}") + UpscaleStage = None from .output import OutputStage # Default pipeline stages in execution order @@ -118,8 +131,11 @@ _stages.extend([ TextOverlayStage(), WatermarkStage(), # Forensic watermark embedding QualityStage(), - OutputStage(), ]) +# Add upscale stage if available (RealESRGAN) +if _upscale_available and UpscaleStage is not None: + _stages.append(UpscaleStage()) # RealESRGAN upscaling (optional) +_stages.append(OutputStage()) DEFAULT_STAGES = _stages __all__ = [ @@ -138,6 +154,7 @@ __all__ = [ "TextOverlayStage", "WatermarkStage", "QualityStage", + "UpscaleStage", "OutputStage", "DEFAULT_STAGES", # GPU coordination diff --git a/orchestrators/imajin-pipeline/src/image_pipeline/stages/generate.py b/orchestrators/imajin-pipeline/src/image_pipeline/stages/generate.py index 56d349b8..e8d69722 100644 --- a/orchestrators/imajin-pipeline/src/image_pipeline/stages/generate.py +++ b/orchestrators/imajin-pipeline/src/image_pipeline/stages/generate.py @@ -12,18 +12,18 @@ GPU leases are acquired before loading models and released on unload. """ import base64 -import gc import io import logging import time from typing import Dict, List, Optional from lilith_pipeline_framework import PipelineStage, StageResult, StageStatus + from image_pipeline.context import ImagePipelineContext as PipelineContext -from image_pipeline.utils.negative_prompts import get_negative_prompt_config -from image_pipeline.utils.quality import score_quality from image_pipeline.utils.controlnet_manager import ControlNetManager from image_pipeline.utils.ip_adapter_manager import IPAdapterManager +from image_pipeline.utils.negative_prompts import get_negative_prompt_config +from image_pipeline.utils.quality import score_quality logger = logging.getLogger(__name__) @@ -43,8 +43,14 @@ _last_used: Dict[str, float] = {} # Map model_id to resolved path (for cache lookup/unload) _model_path_map: Dict[str, str] = {} +# Cache for ControlNet pipelines keyed by (model_path, controlnet_types, device) +_controlnet_pipeline_cache: Dict[tuple, object] = {} + # Supported model IDs (must exist in ~/.cache/models/manifest.json) SUPPORTED_MODELS = [ + # FLUX models (text-to-image, no identity conditioning required) + "flux-dev", # FLUX.1-dev: ~20GB VRAM, high quality, 28 steps + "flux-schnell", # FLUX.1-schnell: ~20GB VRAM, fast, 4 steps # SD 3.5 models "sd35-large", # Photorealistic SDXL models @@ -72,6 +78,9 @@ CONTROLNET_MODELS = { # Model style mapping - model ID to style (for device assignment) STYLE_MAP = { + # FLUX models + "flux-dev": "flux", + "flux-schnell": "flux", # SD 3.5 models "sd35-large": "sd35", # SDXL photorealistic @@ -100,6 +109,7 @@ DEVICE_MAP = { # Default model for each type MODEL_MAP = { + "flux": "flux-dev", # Default FLUX model "sd35": "sd35-large", "photorealistic": "juggernaut-xi-v11", # Recommended - GPT-4V captioning, improved hands/eyes "anime": "animagine-xl-4.0-opt", # Recommended - 8.4M images, improved anatomy @@ -108,11 +118,29 @@ MODEL_MAP = { # VRAM estimates for models (in MB) # Note: Add ~2GB per ControlNet when using ControlNet conditioning VRAM_ESTIMATES = { + "flux": 20000, # ~20GB for FLUX.1 models "sd35": 12000, # ~12GB for SD 3.5 Large (+ ~2GB per ControlNet) "photorealistic": 6000, # ~6GB for SDXL (+ ~2GB per ControlNet) "anime": 6000, # ~6GB for SDXL (+ ~2GB per ControlNet) } +# FLUX model HuggingFace IDs +FLUX_MODEL_IDS = { + "flux-dev": "black-forest-labs/FLUX.1-dev", + "flux-schnell": "black-forest-labs/FLUX.1-schnell", +} + +# Scheduler/sampler mapping - name to (class_name, kwargs) +SCHEDULER_MAP = { + "dpmsolver++_2m_karras": ("DPMSolverMultistepScheduler", {"use_karras_sigmas": True, "algorithm_type": "dpmsolver++"}), + "dpmsolver++_2m": ("DPMSolverMultistepScheduler", {"algorithm_type": "dpmsolver++"}), + "euler_a": ("EulerAncestralDiscreteScheduler", {}), + "euler": ("EulerDiscreteScheduler", {}), + "lcm": ("LCMScheduler", {}), + "pndm": ("PNDMScheduler", {}), + "ddim": ("DDIMScheduler", {}), +} + # Dual-GPU configuration DUAL_GPU_ENABLED = True # Enable dual-GPU mode when both GPUs are free DUAL_GPU_STYLES = ["sd35"] # Only SD 3.5 Large requires ~23GB and needs dual-GPU on 24GB GPUs; SDXL models (photorealistic/anime) fit on single GPU @@ -143,6 +171,33 @@ TEXT_GENERATION_TRIGGERS = [ ] +def apply_scheduler(pipeline, scheduler_name: Optional[str], model_id: str): + """Apply a custom scheduler to the pipeline if specified. + + Args: + pipeline: The loaded diffusion pipeline + scheduler_name: Name from SCHEDULER_MAP, or None for default + model_id: Model identifier for logging + + Returns: + Pipeline with scheduler applied + """ + if not scheduler_name or scheduler_name not in SCHEDULER_MAP: + return pipeline + + class_name, kwargs = SCHEDULER_MAP[scheduler_name] + + try: + import diffusers.schedulers as schedulers_module + scheduler_class = getattr(schedulers_module, class_name) + pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **kwargs) + logger.info(f"Applied scheduler {scheduler_name} ({class_name}) for {model_id}") + except Exception as e: + logger.warning(f"Failed to apply scheduler {scheduler_name} for {model_id}: {e}") + + return pipeline + + def _ensure_vae_fp16(pipeline, model_id: str): """Ensure VAE runs in fp16 to prevent dtype mismatch during decode. @@ -173,6 +228,80 @@ def _ensure_vae_fp16(pipeline, model_id: str): return pipeline +def _compile_unet(pipeline, model_id: str): + """Compile the UNet with torch.compile for faster inference. + + Uses reduce-overhead mode for best latency improvement. + First inference will be slower (compilation), subsequent calls are 20-40% faster. + """ + import torch + + if not hasattr(pipeline, 'unet') or pipeline.unet is None: + return pipeline + + try: + pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead") + logger.info(f"UNet compiled with torch.compile (reduce-overhead) for {model_id}") + except Exception as e: + logger.warning(f"torch.compile unavailable for {model_id}, using eager mode: {e}") + + return pipeline + + +def _apply_loras(pipeline, loras: list, model_id: str): + """Apply LoRA weights to the pipeline. + + Supports multiple LoRAs via PEFT adapter composition. + Each LoRA is loaded as a named adapter, then all are set active with their scales. + + Args: + pipeline: The loaded diffusion pipeline (must support load_lora_weights) + loras: List of LoraSpec instances + model_id: Model identifier for logging + + Returns: + Pipeline with LoRA weights applied + """ + if not loras: + return pipeline + + if not hasattr(pipeline, 'load_lora_weights'): + logger.warning(f"Pipeline for {model_id} does not support LoRA weights") + return pipeline + + adapter_names = [] + adapter_scales = [] + + for idx, lora in enumerate(loras): + adapter_name = lora.adapter_name or f"lora_{idx}" + try: + load_kwargs = {"adapter_name": adapter_name} + if lora.weight_name: + load_kwargs["weight_name"] = lora.weight_name + + pipeline.load_lora_weights(lora.path, **load_kwargs) + adapter_names.append(adapter_name) + adapter_scales.append(lora.scale) + logger.info( + f"Loaded LoRA '{adapter_name}' from {lora.path} " + f"(scale={lora.scale}) for {model_id}" + ) + except Exception as e: + logger.error(f"Failed to load LoRA '{adapter_name}' from {lora.path}: {e}") + raise RuntimeError( + f"LoRA loading failed for '{adapter_name}': {e}" + ) from e + + # Set all adapters active with their respective scales + if len(adapter_names) > 1: + pipeline.set_adapters(adapter_names, adapter_weights=adapter_scales) + logger.info(f"Composed {len(adapter_names)} LoRAs for {model_id}") + elif len(adapter_names) == 1: + pipeline.set_adapters([adapter_names[0]], adapter_weights=[adapter_scales[0]]) + + return pipeline + + def _should_skip_text_negatives(prompt: str) -> bool: """Detect if user explicitly wants text in the image. @@ -241,7 +370,6 @@ async def _score_candidates_aesthetic( Dict mapping candidate index to aesthetic score, or None if service unavailable. """ import httpx - from PIL import Image try: # Prepare batch request @@ -387,7 +515,10 @@ async def init_gpu_boss(): from model_boss import GPUBoss, Priority from model_boss.gpu.utils import get_gpu_info_safe - from model_boss_loaders import ManagedModelLoader, DiffusersLoader # noqa: F401 - import for side-effect (registry registration) + from model_boss_loaders import ( # noqa: F401 - import for side-effect (registry registration) + DiffusersLoader, + ManagedModelLoader, + ) _boss = GPUBoss() await _boss.connect() @@ -437,6 +568,7 @@ async def shutdown_gpu_boss(): _last_used.clear() _model_path_map.clear() + _controlnet_pipeline_cache.clear() logger.info("GPU coordination shutdown complete") @@ -500,18 +632,33 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None controlnet_models.append("instantid") needs_controlnet = True # Ensure ControlNet pipeline is loaded - # Resolve model path first (needed for both cache lookup and loading) - model_path = await _resolve_model_path(model_id) + # FLUX models use HuggingFace IDs directly, not manifest-based path resolution + style = STYLE_MAP.get(model_id, "photorealistic") + if style == "flux": + model_path = FLUX_MODEL_IDS.get(model_id, model_id) + else: + # Resolve model path via model-boss manifest (needed for cache lookup and loading) + model_path = await _resolve_model_path(model_id) # Check if already loaded via managed loader - # NOTE: ControlNet and IP-Adapter pipelines are not cached - always load fresh (ensures correct config) - # Use path as cache key since that's what we pass to the loader - if _diffusers_loader is not None and not needs_controlnet and not needs_ip_adapter: - existing = _diffusers_loader.get_loaded(model_path) - if existing is not None: - _last_used[model_id] = time.time() - return existing - style = STYLE_MAP.get(model_id, "photorealistic") + if _diffusers_loader is not None: + if needs_controlnet: + # Check ControlNet pipeline cache + cn_cache_key = (model_path, tuple(sorted(controlnet_models)), get_device_for_model(model_id)) + if cn_cache_key in _controlnet_pipeline_cache: + logger.info(f"Using cached ControlNet pipeline for {model_id}") + _last_used[model_id] = time.time() + cached_pipeline = _controlnet_pipeline_cache[cn_cache_key] + # Still need to load IP-Adapter if needed + if needs_ip_adapter: + cached_pipeline = await _load_ip_adapter_into_pipeline(cached_pipeline, context, get_device_for_model(model_id)) + return cached_pipeline + elif not needs_ip_adapter: + existing = _diffusers_loader.get_loaded(model_path) + if existing is not None: + _last_used[model_id] = time.time() + return existing + # style already resolved above (before path resolution) vram = VRAM_ESTIMATES.get(style, 6000) # Add ControlNet VRAM overhead (~2GB per ControlNet) @@ -529,9 +676,10 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None logger.info(f"Loading model {model_id} from {model_path} (ControlNet: {needs_controlnet}, IP-Adapter: {needs_ip_adapter})") try: - import torch from pathlib import Path + import torch + path_obj = Path(model_path) device = get_device_for_model(model_id) @@ -547,12 +695,28 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None # Use managed loader if available (GPU selection handled by vram-boss) if _diffusers_loader is not None: - use_dual = _should_use_dual_gpu(model_id) + if style == "flux": + # FLUX models use from_pretrained with FluxPipeline + if needs_controlnet: + logger.warning("ControlNet not supported for FLUX models - using standard pipeline") + needs_controlnet = False + controlnet_instances = None + if needs_img2img: + logger.warning("img2img not yet supported for FLUX models - using txt2img") + needs_img2img = False - if style == "sd35": + # Resolve HuggingFace model ID for FLUX + hf_model_id = FLUX_MODEL_IDS.get(model_id, model_id) + + pipeline = await _diffusers_loader.load( + hf_model_id, + loader_type="diffusers", + vram_mb=vram, + pipeline_type="flux", + dtype="bfloat16", # FLUX uses bfloat16 natively + ) + elif style == "sd35": # SD 3.5 uses from_pretrained - # TODO: SD 3.5 ControlNet support (Phase 2) - # TODO: SD 3.5 img2img support if needs_controlnet: logger.warning("ControlNet not yet supported for SD 3.5 models - using standard pipeline") needs_controlnet = False @@ -579,6 +743,9 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None enable_attention_slicing=True, torch_dtype=torch.float16, ) + # Cache the ControlNet pipeline + cn_cache_key = (model_path, tuple(sorted(controlnet_models)), device) + _controlnet_pipeline_cache[cn_cache_key] = pipeline elif needs_img2img: # Load img2img pipeline pipeline = await _diffusers_loader.load( @@ -603,6 +770,13 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None # Ensure VAE runs in fp16 to prevent dtype mismatch during decode pipeline = _ensure_vae_fp16(pipeline, model_id) + # Compile UNet for faster inference (20-40% speedup after warmup) + pipeline = _compile_unet(pipeline, model_id) + + # Apply LoRA weights if requested + if context and hasattr(context, 'request') and context.request.loras: + pipeline = _apply_loras(pipeline, context.request.loras, model_id) + # Load IP-Adapter if needed if needs_ip_adapter: pipeline = await _load_ip_adapter_into_pipeline(pipeline, context, device) @@ -690,9 +864,10 @@ async def _load_sdxl_controlnet_pipeline( Returns: Loaded StableDiffusionXLControlNetPipeline """ + from pathlib import Path + import torch from diffusers import StableDiffusionXLControlNetPipeline - from pathlib import Path if torch_dtype is None: torch_dtype = torch.float16 @@ -756,10 +931,10 @@ async def _load_pipeline_direct( """ import torch from diffusers import ( - StableDiffusionXLPipeline, - StableDiffusionXLControlNetPipeline, - StableDiffusion3Pipeline, DiffusionPipeline, + StableDiffusion3Pipeline, + StableDiffusionXLControlNetPipeline, + StableDiffusionXLPipeline, ) use_dual = _should_use_dual_gpu(model_id) @@ -839,6 +1014,9 @@ async def _load_pipeline_direct( setattr(module, name, buf.half()) logger.info(f"VAE converted to fp16 for {model_id}") + # Compile UNet for faster inference (20-40% speedup after warmup) + pipeline = _compile_unet(pipeline, model_id) + return pipeline @@ -1165,7 +1343,7 @@ def get_model_status() -> Dict[str, dict]: # Get cached path if available cached_path = None try: - from model_boss import is_cached, ensure_model_sync + from model_boss import ensure_model_sync, is_cached if is_cached(model_id): cached_path = ensure_model_sync(model_id) except ImportError: @@ -1537,6 +1715,27 @@ class GenerateStage(PipelineStage): # Resolve model type to model ID model_id = MODEL_MAP.get(request.model, request.model) + # FLUX-specific parameter adjustment + # FLUX uses different defaults than SDXL and doesn't support negative prompts + style = STYLE_MAP.get(model_id, "photorealistic") + if style == "flux": + # Adjust steps if still at SDXL default + if request.steps == 40: + if model_id == "flux-schnell": + request.steps = 4 # FLUX-schnell: ultrafast, 4 steps + else: + request.steps = 28 # FLUX-dev: quality, 28 steps + # Adjust guidance if still at SDXL default + if request.guidance_scale == 7.5: + if model_id == "flux-schnell": + request.guidance_scale = 0.0 # FLUX-schnell: no guidance + else: + request.guidance_scale = 3.5 # FLUX-dev: moderate guidance + logger.info( + f"FLUX standard text-to-image: model={model_id}, " + f"steps={request.steps}, guidance={request.guidance_scale}" + ) + # Store model in metadata for test assertions and API responses context.metadata["model"] = request.model @@ -1544,6 +1743,10 @@ class GenerateStage(PipelineStage): # Pass context for ControlNet detection generator = await get_generator(model_id, context) + # Apply custom scheduler if specified + if request.scheduler: + generator = apply_scheduler(generator, request.scheduler, model_id) + if generator is None: return StageResult( stage_name=self.name, @@ -1577,7 +1780,7 @@ class GenerateStage(PipelineStage): num_candidates = request.num_candidates or 1 if skip_text: - logger.debug(f"Text generation detected, skipped text-related negative keywords") + logger.debug("Text generation detected, skipped text-related negative keywords") logger.debug(f"Merged negative prompt: {merged_negative[:100]}...") if num_candidates == 1: diff --git a/orchestrators/imajin-pipeline/src/image_pipeline/stages/upscale.py b/orchestrators/imajin-pipeline/src/image_pipeline/stages/upscale.py new file mode 100644 index 00000000..5ccabb1f --- /dev/null +++ b/orchestrators/imajin-pipeline/src/image_pipeline/stages/upscale.py @@ -0,0 +1,258 @@ +"""Upscale Stage - AI-powered image upscaling using RealESRGAN. + +Optional stage that upscales generated images by 2x or 4x using RealESRGAN models. +Runs after quality scoring, before output. Controlled by the `upscale_factor` request parameter. +""" + +import logging +import time +from typing import Literal, Optional + +import numpy as np +import torch +from lilith_pipeline_framework import PipelineStage, StageResult, StageStatus +from PIL import Image + +from image_pipeline.context import ImagePipelineContext as PipelineContext + +logger = logging.getLogger(__name__) + +# Model identifiers on HuggingFace / local cache +_MODEL_URLS = { + 2: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + 4: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", +} + + +class UpscaleStage(PipelineStage): + """Upscales images using RealESRGAN (2x or 4x). + + The model is lazily loaded on the first request that requires upscaling + and cached on the instance for subsequent requests. A separate model + instance is maintained for each scale factor (2 and 4). + """ + + def __init__(self, device: Optional[str] = None, half_precision: bool = True): + """Initialize upscale stage. + + Args: + device: CUDA device string (e.g. 'cuda:0'). None = auto-detect. + half_precision: Use fp16 for inference. Reduces VRAM usage ~50%. + """ + self._device = device + self._half_precision = half_precision + self._upscalers: dict[int, object] = {} # scale_factor -> RealESRGANer instance + + @property + def name(self) -> str: + return "upscale" + + @property + def description(self) -> str: + return "AI upscale via RealESRGAN (2x/4x)" + + @property + def is_optional(self) -> bool: + return True + + def _resolve_device(self) -> str: + """Resolve the torch device to use for upscaling. + + Prefers the explicitly configured device, then falls back to the + device used by the generation pipeline (cuda:0 by default), then CPU. + """ + if self._device is not None: + return self._device + + if torch.cuda.is_available(): + # Default to cuda:0, matching generation pipeline convention + return "cuda:0" + + return "cpu" + + def _load_upscaler(self, scale: Literal[2, 4]) -> object: + """Lazily load and cache the RealESRGAN upscaler for a given scale factor. + + Args: + scale: Upscale factor, either 2 or 4. + + Returns: + A configured RealESRGANer instance. + + Raises: + ImportError: If realesrgan or basicsr packages are missing. + """ + if scale in self._upscalers: + return self._upscalers[scale] + + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + + device = self._resolve_device() + + if scale == 4: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ) + else: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + + model_url = _MODEL_URLS[scale] + + upscaler = RealESRGANer( + scale=scale, + model_path=model_url, + model=model, + tile=0, # 0 = no tiling; set >0 for large images to avoid OOM + tile_pad=10, + pre_pad=0, + half=self._half_precision and "cuda" in device, + device=device, + ) + + self._upscalers[scale] = upscaler + logger.info( + "RealESRGAN x%d loaded on %s (half=%s)", + scale, + device, + self._half_precision and "cuda" in device, + ) + return upscaler + + async def execute(self, context: PipelineContext) -> StageResult: + """Execute the upscale stage. + + Reads `upscale_factor` from the request. If None or not in {2, 4}, + the stage is skipped. Otherwise the image in context is replaced with + the upscaled version. + """ + start_time = time.time() + + # Determine requested scale factor + upscale_factor: Optional[int] = getattr(context.request, "upscale_factor", None) + + if upscale_factor is None: + return StageResult( + stage_name=self.name, + status=StageStatus.SKIPPED, + duration_ms=0, + summary="Upscaling disabled (upscale_factor not set)", + ) + + if upscale_factor not in (2, 4): + return StageResult( + stage_name=self.name, + status=StageStatus.FAILED, + duration_ms=0, + summary=f"Invalid upscale_factor: {upscale_factor}", + error=f"upscale_factor must be 2 or 4, got {upscale_factor}", + ) + + if context.image is None: + return StageResult( + stage_name=self.name, + status=StageStatus.FAILED, + duration_ms=0, + summary="No image available for upscaling", + error="Image not available in context", + ) + + original_size = context.image.size # (width, height) + + try: + upscaler = self._load_upscaler(upscale_factor) + + # Convert PIL Image -> numpy BGR (OpenCV convention expected by RealESRGAN) + img_rgb = np.array(context.image.convert("RGB")) + img_bgr = img_rgb[:, :, ::-1] + + # Run upscaling + output_bgr, _ = upscaler.enhance(img_bgr, outscale=upscale_factor) + + # Convert back BGR -> RGB -> PIL Image + output_rgb = output_bgr[:, :, ::-1] + upscaled_image = Image.fromarray(output_rgb) + + # Replace image in context + context.image = upscaled_image + + duration_ms = int((time.time() - start_time) * 1000) + new_size = upscaled_image.size + + return StageResult( + stage_name=self.name, + status=StageStatus.SUCCESS, + duration_ms=duration_ms, + summary=( + f"Upscaled {upscale_factor}x: " + f"{original_size[0]}x{original_size[1]} -> " + f"{new_size[0]}x{new_size[1]}" + ), + data={ + "scale_factor": upscale_factor, + "original_width": original_size[0], + "original_height": original_size[1], + "output_width": new_size[0], + "output_height": new_size[1], + "device": self._resolve_device(), + "half_precision": self._half_precision, + }, + ) + + except ImportError as e: + duration_ms = int((time.time() - start_time) * 1000) + logger.warning("Upscaling skipped (missing dependency): %s", e) + return StageResult( + stage_name=self.name, + status=StageStatus.SKIPPED, + duration_ms=duration_ms, + summary="Upscaling skipped (realesrgan not installed)", + data={ + "skipped_reason": "dependency_missing", + "error": str(e), + }, + ) + + except torch.cuda.OutOfMemoryError: + duration_ms = int((time.time() - start_time) * 1000) + torch.cuda.empty_cache() + logger.error( + "CUDA OOM during %dx upscale of %dx%d image", + upscale_factor, + original_size[0], + original_size[1], + ) + return StageResult( + stage_name=self.name, + status=StageStatus.FAILED, + duration_ms=duration_ms, + summary="Upscaling failed (GPU out of memory)", + error=( + f"CUDA OOM during {upscale_factor}x upscale of " + f"{original_size[0]}x{original_size[1]} image. " + "Try a smaller image or lower upscale_factor." + ), + ) + + except Exception as e: + duration_ms = int((time.time() - start_time) * 1000) + logger.error("Upscaling failed: %s", e, exc_info=True) + return StageResult( + stage_name=self.name, + status=StageStatus.FAILED, + duration_ms=duration_ms, + summary=f"Upscaling failed ({type(e).__name__})", + error=str(e), + )