feat(image-pipeline): Add generate and upscale stages with image processing models

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Lilith 2026-03-02 20:58:45 -08:00
parent 7e3b66308d
commit ac100a68c9
5 changed files with 568 additions and 45 deletions

View file

@ -13,7 +13,7 @@ Stages:
"""
from .context import ImagePipelineContext
from .models import ImagePipelineRequest, TextSpan
from .models import ImagePipelineRequest, LoraSpec, TextSpan
from .stages import (
DEFAULT_STAGES,
GenerateStage,
@ -28,6 +28,7 @@ from .stages import (
__all__ = [
"ImagePipelineContext",
"ImagePipelineRequest",
"LoraSpec",
"TextSpan",
"ValidateStage",
"GenerateStage",

View file

@ -8,6 +8,32 @@ from pydantic import BaseModel, Field
from .utils.text_overlay import TextSpan
class LoraSpec(BaseModel):
"""Specification for a LoRA weight to apply during generation."""
path: str = Field(
...,
description="Path to LoRA weights file (safetensors or bin). "
"Can be a local path or HuggingFace model ID.",
)
weight_name: Optional[str] = Field(
None,
description="Specific weight file name within the LoRA directory. "
"Required when path points to a directory with multiple weight files.",
)
scale: float = Field(
1.0,
ge=0.0,
le=2.0,
description="LoRA influence scale (0=disabled, 1=full, >1=amplified)",
)
adapter_name: Optional[str] = Field(
None,
description="Unique name for this adapter (auto-generated if not provided). "
"Used for multi-LoRA composition.",
)
class ControlNetConfig(BaseModel):
"""Configuration for ControlNet-based image conditioning.
@ -132,6 +158,17 @@ class ImagePipelineRequest(BaseModel):
steps: int = Field(40, ge=1, le=50) # Increased from 30 for better quality
guidance_scale: float = Field(7.5, ge=1.0, le=20.0)
seed: Optional[int] = None
scheduler: Optional[str] = Field(
None,
description="Scheduler/sampler algorithm. Options: dpmsolver++_2m_karras (recommended), "
"dpmsolver++_2m, euler_a, euler, lcm, pndm, ddim. None = model default."
)
# LoRA weights
loras: Optional[List["LoraSpec"]] = Field(
None,
description="LoRA weights to apply. Multiple LoRAs are composed additively.",
)
# img2img options
init_image_base64: Optional[str] = Field(None, description="Base64-encoded initialization image for img2img generation")
@ -255,6 +292,13 @@ class ImagePipelineRequest(BaseModel):
description="Fail pipeline if aesthetic score below threshold"
)
# Upscaling options (RealESRGAN)
upscale_factor: Optional[int] = Field(
None,
description="Upscale factor after generation (2 or 4). None = no upscaling. "
"Uses RealESRGAN_x2plus (2x) or RealESRGAN_x4plus (4x).",
)
# Identity-preserving generation options (FLUX+PuLID or IP-Adapter + InstantID)
identity_id: Optional[str] = Field(
None,

View file

@ -1,6 +1,6 @@
"""Pipeline stages for image generation.
16-stage pipeline:
17-stage pipeline:
1. ValidateStage - Parameter validation and layout resolution
2. ImageLoadingStage - Decode initialization images for img2img (optional)
3. IdentityConditioningStage - IP-Adapter face conditioning (optional)
@ -16,11 +16,13 @@
13. TextOverlayStage - Typography rendering
14. WatermarkStage - Forensic watermarking
15. QualityStage - Quality scoring
16. OutputStage - Format conversion and encoding
16. UpscaleStage - RealESRGAN upscaling (optional)
17. OutputStage - Format conversion and encoding
"""
from .validate import ValidateStage
from .image_loading import ImageLoadingStage
from .validate import ValidateStage
try:
from .identity_conditioning import IdentityConditioningStage
_identity_conditioning_available = True
@ -38,25 +40,26 @@ except ImportError as e:
logging.warning(f"ImageConditioningStage not available (ControlNet disabled): {e}")
ImageConditioningStage = None
from .generate import (
DEVICE_MAP,
MODEL_MAP,
STYLE_MAP,
GenerateStage,
_boss,
_diffusers_loader,
_last_used,
check_idle_timeout,
get_device_for_model,
get_model_status,
# GPU coordination
init_gpu_boss,
shutdown_gpu_boss,
_boss,
_diffusers_loader,
touch_models,
unload_generator,
# Generator management
unload_generators,
unload_generator,
warmup_models,
check_idle_timeout,
get_model_status,
touch_models,
get_device_for_model,
_last_used,
MODEL_MAP,
STYLE_MAP,
DEVICE_MAP,
)
try:
from .identity_verification import IdentityVerificationStage
_identity_verification_available = True
@ -74,6 +77,7 @@ except ImportError as e:
logging.warning(f"AnatomyFixStage not available: {e}")
AnatomyFixStage = None
from .watermark_removal import WatermarkRemovalStage
try:
from .background_removal import BackgroundRemovalStage
_background_removal_available = True
@ -82,12 +86,21 @@ except ImportError as e:
import logging
logging.warning(f"BackgroundRemovalStage not available (rembg disabled): {e}")
BackgroundRemovalStage = None
from .moderate import ModerateStage
from .semantic_validate import SemanticValidationStage
from .aesthetic import AestheticValidationStage
from .moderate import ModerateStage
from .quality import QualityStage
from .semantic_validate import SemanticValidationStage
from .text_overlay import TextOverlayStage
from .watermark import WatermarkStage
from .quality import QualityStage
try:
from .upscale import UpscaleStage
_upscale_available = True
except ImportError as e:
_upscale_available = False
import logging
logging.warning(f"UpscaleStage not available (realesrgan disabled): {e}")
UpscaleStage = None
from .output import OutputStage
# Default pipeline stages in execution order
@ -118,8 +131,11 @@ _stages.extend([
TextOverlayStage(),
WatermarkStage(), # Forensic watermark embedding
QualityStage(),
OutputStage(),
])
# Add upscale stage if available (RealESRGAN)
if _upscale_available and UpscaleStage is not None:
_stages.append(UpscaleStage()) # RealESRGAN upscaling (optional)
_stages.append(OutputStage())
DEFAULT_STAGES = _stages
__all__ = [
@ -138,6 +154,7 @@ __all__ = [
"TextOverlayStage",
"WatermarkStage",
"QualityStage",
"UpscaleStage",
"OutputStage",
"DEFAULT_STAGES",
# GPU coordination

View file

@ -12,18 +12,18 @@ GPU leases are acquired before loading models and released on unload.
"""
import base64
import gc
import io
import logging
import time
from typing import Dict, List, Optional
from lilith_pipeline_framework import PipelineStage, StageResult, StageStatus
from image_pipeline.context import ImagePipelineContext as PipelineContext
from image_pipeline.utils.negative_prompts import get_negative_prompt_config
from image_pipeline.utils.quality import score_quality
from image_pipeline.utils.controlnet_manager import ControlNetManager
from image_pipeline.utils.ip_adapter_manager import IPAdapterManager
from image_pipeline.utils.negative_prompts import get_negative_prompt_config
from image_pipeline.utils.quality import score_quality
logger = logging.getLogger(__name__)
@ -43,8 +43,14 @@ _last_used: Dict[str, float] = {}
# Map model_id to resolved path (for cache lookup/unload)
_model_path_map: Dict[str, str] = {}
# Cache for ControlNet pipelines keyed by (model_path, controlnet_types, device)
_controlnet_pipeline_cache: Dict[tuple, object] = {}
# Supported model IDs (must exist in ~/.cache/models/manifest.json)
SUPPORTED_MODELS = [
# FLUX models (text-to-image, no identity conditioning required)
"flux-dev", # FLUX.1-dev: ~20GB VRAM, high quality, 28 steps
"flux-schnell", # FLUX.1-schnell: ~20GB VRAM, fast, 4 steps
# SD 3.5 models
"sd35-large",
# Photorealistic SDXL models
@ -72,6 +78,9 @@ CONTROLNET_MODELS = {
# Model style mapping - model ID to style (for device assignment)
STYLE_MAP = {
# FLUX models
"flux-dev": "flux",
"flux-schnell": "flux",
# SD 3.5 models
"sd35-large": "sd35",
# SDXL photorealistic
@ -100,6 +109,7 @@ DEVICE_MAP = {
# Default model for each type
MODEL_MAP = {
"flux": "flux-dev", # Default FLUX model
"sd35": "sd35-large",
"photorealistic": "juggernaut-xi-v11", # Recommended - GPT-4V captioning, improved hands/eyes
"anime": "animagine-xl-4.0-opt", # Recommended - 8.4M images, improved anatomy
@ -108,11 +118,29 @@ MODEL_MAP = {
# VRAM estimates for models (in MB)
# Note: Add ~2GB per ControlNet when using ControlNet conditioning
VRAM_ESTIMATES = {
"flux": 20000, # ~20GB for FLUX.1 models
"sd35": 12000, # ~12GB for SD 3.5 Large (+ ~2GB per ControlNet)
"photorealistic": 6000, # ~6GB for SDXL (+ ~2GB per ControlNet)
"anime": 6000, # ~6GB for SDXL (+ ~2GB per ControlNet)
}
# FLUX model HuggingFace IDs
FLUX_MODEL_IDS = {
"flux-dev": "black-forest-labs/FLUX.1-dev",
"flux-schnell": "black-forest-labs/FLUX.1-schnell",
}
# Scheduler/sampler mapping - name to (class_name, kwargs)
SCHEDULER_MAP = {
"dpmsolver++_2m_karras": ("DPMSolverMultistepScheduler", {"use_karras_sigmas": True, "algorithm_type": "dpmsolver++"}),
"dpmsolver++_2m": ("DPMSolverMultistepScheduler", {"algorithm_type": "dpmsolver++"}),
"euler_a": ("EulerAncestralDiscreteScheduler", {}),
"euler": ("EulerDiscreteScheduler", {}),
"lcm": ("LCMScheduler", {}),
"pndm": ("PNDMScheduler", {}),
"ddim": ("DDIMScheduler", {}),
}
# Dual-GPU configuration
DUAL_GPU_ENABLED = True # Enable dual-GPU mode when both GPUs are free
DUAL_GPU_STYLES = ["sd35"] # Only SD 3.5 Large requires ~23GB and needs dual-GPU on 24GB GPUs; SDXL models (photorealistic/anime) fit on single GPU
@ -143,6 +171,33 @@ TEXT_GENERATION_TRIGGERS = [
]
def apply_scheduler(pipeline, scheduler_name: Optional[str], model_id: str):
"""Apply a custom scheduler to the pipeline if specified.
Args:
pipeline: The loaded diffusion pipeline
scheduler_name: Name from SCHEDULER_MAP, or None for default
model_id: Model identifier for logging
Returns:
Pipeline with scheduler applied
"""
if not scheduler_name or scheduler_name not in SCHEDULER_MAP:
return pipeline
class_name, kwargs = SCHEDULER_MAP[scheduler_name]
try:
import diffusers.schedulers as schedulers_module
scheduler_class = getattr(schedulers_module, class_name)
pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **kwargs)
logger.info(f"Applied scheduler {scheduler_name} ({class_name}) for {model_id}")
except Exception as e:
logger.warning(f"Failed to apply scheduler {scheduler_name} for {model_id}: {e}")
return pipeline
def _ensure_vae_fp16(pipeline, model_id: str):
"""Ensure VAE runs in fp16 to prevent dtype mismatch during decode.
@ -173,6 +228,80 @@ def _ensure_vae_fp16(pipeline, model_id: str):
return pipeline
def _compile_unet(pipeline, model_id: str):
"""Compile the UNet with torch.compile for faster inference.
Uses reduce-overhead mode for best latency improvement.
First inference will be slower (compilation), subsequent calls are 20-40% faster.
"""
import torch
if not hasattr(pipeline, 'unet') or pipeline.unet is None:
return pipeline
try:
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
logger.info(f"UNet compiled with torch.compile (reduce-overhead) for {model_id}")
except Exception as e:
logger.warning(f"torch.compile unavailable for {model_id}, using eager mode: {e}")
return pipeline
def _apply_loras(pipeline, loras: list, model_id: str):
"""Apply LoRA weights to the pipeline.
Supports multiple LoRAs via PEFT adapter composition.
Each LoRA is loaded as a named adapter, then all are set active with their scales.
Args:
pipeline: The loaded diffusion pipeline (must support load_lora_weights)
loras: List of LoraSpec instances
model_id: Model identifier for logging
Returns:
Pipeline with LoRA weights applied
"""
if not loras:
return pipeline
if not hasattr(pipeline, 'load_lora_weights'):
logger.warning(f"Pipeline for {model_id} does not support LoRA weights")
return pipeline
adapter_names = []
adapter_scales = []
for idx, lora in enumerate(loras):
adapter_name = lora.adapter_name or f"lora_{idx}"
try:
load_kwargs = {"adapter_name": adapter_name}
if lora.weight_name:
load_kwargs["weight_name"] = lora.weight_name
pipeline.load_lora_weights(lora.path, **load_kwargs)
adapter_names.append(adapter_name)
adapter_scales.append(lora.scale)
logger.info(
f"Loaded LoRA '{adapter_name}' from {lora.path} "
f"(scale={lora.scale}) for {model_id}"
)
except Exception as e:
logger.error(f"Failed to load LoRA '{adapter_name}' from {lora.path}: {e}")
raise RuntimeError(
f"LoRA loading failed for '{adapter_name}': {e}"
) from e
# Set all adapters active with their respective scales
if len(adapter_names) > 1:
pipeline.set_adapters(adapter_names, adapter_weights=adapter_scales)
logger.info(f"Composed {len(adapter_names)} LoRAs for {model_id}")
elif len(adapter_names) == 1:
pipeline.set_adapters([adapter_names[0]], adapter_weights=[adapter_scales[0]])
return pipeline
def _should_skip_text_negatives(prompt: str) -> bool:
"""Detect if user explicitly wants text in the image.
@ -241,7 +370,6 @@ async def _score_candidates_aesthetic(
Dict mapping candidate index to aesthetic score, or None if service unavailable.
"""
import httpx
from PIL import Image
try:
# Prepare batch request
@ -387,7 +515,10 @@ async def init_gpu_boss():
from model_boss import GPUBoss, Priority
from model_boss.gpu.utils import get_gpu_info_safe
from model_boss_loaders import ManagedModelLoader, DiffusersLoader # noqa: F401 - import for side-effect (registry registration)
from model_boss_loaders import ( # noqa: F401 - import for side-effect (registry registration)
DiffusersLoader,
ManagedModelLoader,
)
_boss = GPUBoss()
await _boss.connect()
@ -437,6 +568,7 @@ async def shutdown_gpu_boss():
_last_used.clear()
_model_path_map.clear()
_controlnet_pipeline_cache.clear()
logger.info("GPU coordination shutdown complete")
@ -500,18 +632,33 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
controlnet_models.append("instantid")
needs_controlnet = True # Ensure ControlNet pipeline is loaded
# Resolve model path first (needed for both cache lookup and loading)
model_path = await _resolve_model_path(model_id)
# FLUX models use HuggingFace IDs directly, not manifest-based path resolution
style = STYLE_MAP.get(model_id, "photorealistic")
if style == "flux":
model_path = FLUX_MODEL_IDS.get(model_id, model_id)
else:
# Resolve model path via model-boss manifest (needed for cache lookup and loading)
model_path = await _resolve_model_path(model_id)
# Check if already loaded via managed loader
# NOTE: ControlNet and IP-Adapter pipelines are not cached - always load fresh (ensures correct config)
# Use path as cache key since that's what we pass to the loader
if _diffusers_loader is not None and not needs_controlnet and not needs_ip_adapter:
existing = _diffusers_loader.get_loaded(model_path)
if existing is not None:
_last_used[model_id] = time.time()
return existing
style = STYLE_MAP.get(model_id, "photorealistic")
if _diffusers_loader is not None:
if needs_controlnet:
# Check ControlNet pipeline cache
cn_cache_key = (model_path, tuple(sorted(controlnet_models)), get_device_for_model(model_id))
if cn_cache_key in _controlnet_pipeline_cache:
logger.info(f"Using cached ControlNet pipeline for {model_id}")
_last_used[model_id] = time.time()
cached_pipeline = _controlnet_pipeline_cache[cn_cache_key]
# Still need to load IP-Adapter if needed
if needs_ip_adapter:
cached_pipeline = await _load_ip_adapter_into_pipeline(cached_pipeline, context, get_device_for_model(model_id))
return cached_pipeline
elif not needs_ip_adapter:
existing = _diffusers_loader.get_loaded(model_path)
if existing is not None:
_last_used[model_id] = time.time()
return existing
# style already resolved above (before path resolution)
vram = VRAM_ESTIMATES.get(style, 6000)
# Add ControlNet VRAM overhead (~2GB per ControlNet)
@ -529,9 +676,10 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
logger.info(f"Loading model {model_id} from {model_path} (ControlNet: {needs_controlnet}, IP-Adapter: {needs_ip_adapter})")
try:
import torch
from pathlib import Path
import torch
path_obj = Path(model_path)
device = get_device_for_model(model_id)
@ -547,12 +695,28 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
# Use managed loader if available (GPU selection handled by vram-boss)
if _diffusers_loader is not None:
use_dual = _should_use_dual_gpu(model_id)
if style == "flux":
# FLUX models use from_pretrained with FluxPipeline
if needs_controlnet:
logger.warning("ControlNet not supported for FLUX models - using standard pipeline")
needs_controlnet = False
controlnet_instances = None
if needs_img2img:
logger.warning("img2img not yet supported for FLUX models - using txt2img")
needs_img2img = False
if style == "sd35":
# Resolve HuggingFace model ID for FLUX
hf_model_id = FLUX_MODEL_IDS.get(model_id, model_id)
pipeline = await _diffusers_loader.load(
hf_model_id,
loader_type="diffusers",
vram_mb=vram,
pipeline_type="flux",
dtype="bfloat16", # FLUX uses bfloat16 natively
)
elif style == "sd35":
# SD 3.5 uses from_pretrained
# TODO: SD 3.5 ControlNet support (Phase 2)
# TODO: SD 3.5 img2img support
if needs_controlnet:
logger.warning("ControlNet not yet supported for SD 3.5 models - using standard pipeline")
needs_controlnet = False
@ -579,6 +743,9 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
enable_attention_slicing=True,
torch_dtype=torch.float16,
)
# Cache the ControlNet pipeline
cn_cache_key = (model_path, tuple(sorted(controlnet_models)), device)
_controlnet_pipeline_cache[cn_cache_key] = pipeline
elif needs_img2img:
# Load img2img pipeline
pipeline = await _diffusers_loader.load(
@ -603,6 +770,13 @@ async def get_generator(model_id: str, context: Optional[PipelineContext] = None
# Ensure VAE runs in fp16 to prevent dtype mismatch during decode
pipeline = _ensure_vae_fp16(pipeline, model_id)
# Compile UNet for faster inference (20-40% speedup after warmup)
pipeline = _compile_unet(pipeline, model_id)
# Apply LoRA weights if requested
if context and hasattr(context, 'request') and context.request.loras:
pipeline = _apply_loras(pipeline, context.request.loras, model_id)
# Load IP-Adapter if needed
if needs_ip_adapter:
pipeline = await _load_ip_adapter_into_pipeline(pipeline, context, device)
@ -690,9 +864,10 @@ async def _load_sdxl_controlnet_pipeline(
Returns:
Loaded StableDiffusionXLControlNetPipeline
"""
from pathlib import Path
import torch
from diffusers import StableDiffusionXLControlNetPipeline
from pathlib import Path
if torch_dtype is None:
torch_dtype = torch.float16
@ -756,10 +931,10 @@ async def _load_pipeline_direct(
"""
import torch
from diffusers import (
StableDiffusionXLPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusion3Pipeline,
DiffusionPipeline,
StableDiffusion3Pipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLPipeline,
)
use_dual = _should_use_dual_gpu(model_id)
@ -839,6 +1014,9 @@ async def _load_pipeline_direct(
setattr(module, name, buf.half())
logger.info(f"VAE converted to fp16 for {model_id}")
# Compile UNet for faster inference (20-40% speedup after warmup)
pipeline = _compile_unet(pipeline, model_id)
return pipeline
@ -1165,7 +1343,7 @@ def get_model_status() -> Dict[str, dict]:
# Get cached path if available
cached_path = None
try:
from model_boss import is_cached, ensure_model_sync
from model_boss import ensure_model_sync, is_cached
if is_cached(model_id):
cached_path = ensure_model_sync(model_id)
except ImportError:
@ -1537,6 +1715,27 @@ class GenerateStage(PipelineStage):
# Resolve model type to model ID
model_id = MODEL_MAP.get(request.model, request.model)
# FLUX-specific parameter adjustment
# FLUX uses different defaults than SDXL and doesn't support negative prompts
style = STYLE_MAP.get(model_id, "photorealistic")
if style == "flux":
# Adjust steps if still at SDXL default
if request.steps == 40:
if model_id == "flux-schnell":
request.steps = 4 # FLUX-schnell: ultrafast, 4 steps
else:
request.steps = 28 # FLUX-dev: quality, 28 steps
# Adjust guidance if still at SDXL default
if request.guidance_scale == 7.5:
if model_id == "flux-schnell":
request.guidance_scale = 0.0 # FLUX-schnell: no guidance
else:
request.guidance_scale = 3.5 # FLUX-dev: moderate guidance
logger.info(
f"FLUX standard text-to-image: model={model_id}, "
f"steps={request.steps}, guidance={request.guidance_scale}"
)
# Store model in metadata for test assertions and API responses
context.metadata["model"] = request.model
@ -1544,6 +1743,10 @@ class GenerateStage(PipelineStage):
# Pass context for ControlNet detection
generator = await get_generator(model_id, context)
# Apply custom scheduler if specified
if request.scheduler:
generator = apply_scheduler(generator, request.scheduler, model_id)
if generator is None:
return StageResult(
stage_name=self.name,
@ -1577,7 +1780,7 @@ class GenerateStage(PipelineStage):
num_candidates = request.num_candidates or 1
if skip_text:
logger.debug(f"Text generation detected, skipped text-related negative keywords")
logger.debug("Text generation detected, skipped text-related negative keywords")
logger.debug(f"Merged negative prompt: {merged_negative[:100]}...")
if num_candidates == 1:

View file

@ -0,0 +1,258 @@
"""Upscale Stage - AI-powered image upscaling using RealESRGAN.
Optional stage that upscales generated images by 2x or 4x using RealESRGAN models.
Runs after quality scoring, before output. Controlled by the `upscale_factor` request parameter.
"""
import logging
import time
from typing import Literal, Optional
import numpy as np
import torch
from lilith_pipeline_framework import PipelineStage, StageResult, StageStatus
from PIL import Image
from image_pipeline.context import ImagePipelineContext as PipelineContext
logger = logging.getLogger(__name__)
# Model identifiers on HuggingFace / local cache
_MODEL_URLS = {
2: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
4: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
}
class UpscaleStage(PipelineStage):
"""Upscales images using RealESRGAN (2x or 4x).
The model is lazily loaded on the first request that requires upscaling
and cached on the instance for subsequent requests. A separate model
instance is maintained for each scale factor (2 and 4).
"""
def __init__(self, device: Optional[str] = None, half_precision: bool = True):
"""Initialize upscale stage.
Args:
device: CUDA device string (e.g. 'cuda:0'). None = auto-detect.
half_precision: Use fp16 for inference. Reduces VRAM usage ~50%.
"""
self._device = device
self._half_precision = half_precision
self._upscalers: dict[int, object] = {} # scale_factor -> RealESRGANer instance
@property
def name(self) -> str:
return "upscale"
@property
def description(self) -> str:
return "AI upscale via RealESRGAN (2x/4x)"
@property
def is_optional(self) -> bool:
return True
def _resolve_device(self) -> str:
"""Resolve the torch device to use for upscaling.
Prefers the explicitly configured device, then falls back to the
device used by the generation pipeline (cuda:0 by default), then CPU.
"""
if self._device is not None:
return self._device
if torch.cuda.is_available():
# Default to cuda:0, matching generation pipeline convention
return "cuda:0"
return "cpu"
def _load_upscaler(self, scale: Literal[2, 4]) -> object:
"""Lazily load and cache the RealESRGAN upscaler for a given scale factor.
Args:
scale: Upscale factor, either 2 or 4.
Returns:
A configured RealESRGANer instance.
Raises:
ImportError: If realesrgan or basicsr packages are missing.
"""
if scale in self._upscalers:
return self._upscalers[scale]
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
device = self._resolve_device()
if scale == 4:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
else:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
model_url = _MODEL_URLS[scale]
upscaler = RealESRGANer(
scale=scale,
model_path=model_url,
model=model,
tile=0, # 0 = no tiling; set >0 for large images to avoid OOM
tile_pad=10,
pre_pad=0,
half=self._half_precision and "cuda" in device,
device=device,
)
self._upscalers[scale] = upscaler
logger.info(
"RealESRGAN x%d loaded on %s (half=%s)",
scale,
device,
self._half_precision and "cuda" in device,
)
return upscaler
async def execute(self, context: PipelineContext) -> StageResult:
"""Execute the upscale stage.
Reads `upscale_factor` from the request. If None or not in {2, 4},
the stage is skipped. Otherwise the image in context is replaced with
the upscaled version.
"""
start_time = time.time()
# Determine requested scale factor
upscale_factor: Optional[int] = getattr(context.request, "upscale_factor", None)
if upscale_factor is None:
return StageResult(
stage_name=self.name,
status=StageStatus.SKIPPED,
duration_ms=0,
summary="Upscaling disabled (upscale_factor not set)",
)
if upscale_factor not in (2, 4):
return StageResult(
stage_name=self.name,
status=StageStatus.FAILED,
duration_ms=0,
summary=f"Invalid upscale_factor: {upscale_factor}",
error=f"upscale_factor must be 2 or 4, got {upscale_factor}",
)
if context.image is None:
return StageResult(
stage_name=self.name,
status=StageStatus.FAILED,
duration_ms=0,
summary="No image available for upscaling",
error="Image not available in context",
)
original_size = context.image.size # (width, height)
try:
upscaler = self._load_upscaler(upscale_factor)
# Convert PIL Image -> numpy BGR (OpenCV convention expected by RealESRGAN)
img_rgb = np.array(context.image.convert("RGB"))
img_bgr = img_rgb[:, :, ::-1]
# Run upscaling
output_bgr, _ = upscaler.enhance(img_bgr, outscale=upscale_factor)
# Convert back BGR -> RGB -> PIL Image
output_rgb = output_bgr[:, :, ::-1]
upscaled_image = Image.fromarray(output_rgb)
# Replace image in context
context.image = upscaled_image
duration_ms = int((time.time() - start_time) * 1000)
new_size = upscaled_image.size
return StageResult(
stage_name=self.name,
status=StageStatus.SUCCESS,
duration_ms=duration_ms,
summary=(
f"Upscaled {upscale_factor}x: "
f"{original_size[0]}x{original_size[1]} -> "
f"{new_size[0]}x{new_size[1]}"
),
data={
"scale_factor": upscale_factor,
"original_width": original_size[0],
"original_height": original_size[1],
"output_width": new_size[0],
"output_height": new_size[1],
"device": self._resolve_device(),
"half_precision": self._half_precision,
},
)
except ImportError as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.warning("Upscaling skipped (missing dependency): %s", e)
return StageResult(
stage_name=self.name,
status=StageStatus.SKIPPED,
duration_ms=duration_ms,
summary="Upscaling skipped (realesrgan not installed)",
data={
"skipped_reason": "dependency_missing",
"error": str(e),
},
)
except torch.cuda.OutOfMemoryError:
duration_ms = int((time.time() - start_time) * 1000)
torch.cuda.empty_cache()
logger.error(
"CUDA OOM during %dx upscale of %dx%d image",
upscale_factor,
original_size[0],
original_size[1],
)
return StageResult(
stage_name=self.name,
status=StageStatus.FAILED,
duration_ms=duration_ms,
summary="Upscaling failed (GPU out of memory)",
error=(
f"CUDA OOM during {upscale_factor}x upscale of "
f"{original_size[0]}x{original_size[1]} image. "
"Try a smaller image or lower upscale_factor."
),
)
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error("Upscaling failed: %s", e, exc_info=True)
return StageResult(
stage_name=self.name,
status=StageStatus.FAILED,
duration_ms=duration_ms,
summary=f"Upscaling failed ({type(e).__name__})",
error=str(e),
)