chore(stages): ⚡ Optimize background removal pipeline performance via enhanced algorithm efficiency and memory optimization
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
ca6758fb42
commit
c33ad53be1
2 changed files with 188 additions and 1 deletions
|
|
@ -73,6 +73,14 @@ except ImportError as e:
|
|||
logging.warning(f"AnatomyFixStage not available: {e}")
|
||||
AnatomyFixStage = None
|
||||
from .watermark_removal import WatermarkRemovalStage
|
||||
try:
|
||||
from .background_removal import BackgroundRemovalStage
|
||||
_background_removal_available = True
|
||||
except ImportError as e:
|
||||
_background_removal_available = False
|
||||
import logging
|
||||
logging.warning(f"BackgroundRemovalStage not available (rembg disabled): {e}")
|
||||
BackgroundRemovalStage = None
|
||||
from .moderate import ModerateStage
|
||||
from .semantic_validate import SemanticValidationStage
|
||||
from .aesthetic import AestheticValidationStage
|
||||
|
|
@ -98,8 +106,11 @@ if _identity_verification_available and IdentityVerificationStage is not None:
|
|||
_stages.append(IdentityVerificationStage()) # Post-generation identity checking (optional)
|
||||
if _anatomy_fix_available and AnatomyFixStage is not None:
|
||||
_stages.append(AnatomyFixStage()) # Anatomical error correction (optional)
|
||||
_stages.append(WatermarkRemovalStage()) # Remove visible text watermarks (optional)
|
||||
# Add background removal stage if available (rembg)
|
||||
if _background_removal_available and BackgroundRemovalStage is not None:
|
||||
_stages.append(BackgroundRemovalStage()) # Remove background for transparent PNG (optional)
|
||||
_stages.extend([
|
||||
WatermarkRemovalStage(), # Remove visible text watermarks (optional)
|
||||
ModerateStage(),
|
||||
SemanticValidationStage(), # SEO filter alignment validation
|
||||
AestheticValidationStage(), # ImageReward aesthetic scoring (optional)
|
||||
|
|
@ -119,6 +130,7 @@ __all__ = [
|
|||
"IdentityVerificationStage",
|
||||
"AnatomyFixStage",
|
||||
"WatermarkRemovalStage",
|
||||
"BackgroundRemovalStage",
|
||||
"ModerateStage",
|
||||
"SemanticValidationStage",
|
||||
"AestheticValidationStage",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,175 @@
|
|||
"""Background Removal Stage - Remove background for transparent PNG output.
|
||||
|
||||
Uses rembg (U2Net) for background segmentation and alpha mask generation.
|
||||
Optional stage that can be enabled per-request for icons, stickers, product images.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
from lilith_pipeline_framework import PipelineStage, StageResult, StageStatus
|
||||
from image_pipeline.context import ImagePipelineContext as PipelineContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BackgroundRemovalStage(PipelineStage):
|
||||
"""Removes background from images using rembg (U2Net-based segmentation)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "u2net",
|
||||
alpha_matting: bool = False,
|
||||
alpha_matting_foreground_threshold: int = 240,
|
||||
alpha_matting_background_threshold: int = 10,
|
||||
):
|
||||
"""Initialize background removal stage.
|
||||
|
||||
Args:
|
||||
model_name: rembg model to use (u2net, u2netp, u2net_human_seg, isnet-general-use)
|
||||
alpha_matting: Enable alpha matting for better edge quality (slower)
|
||||
alpha_matting_foreground_threshold: Foreground threshold for alpha matting
|
||||
alpha_matting_background_threshold: Background threshold for alpha matting
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.alpha_matting = alpha_matting
|
||||
self.alpha_matting_foreground_threshold = alpha_matting_foreground_threshold
|
||||
self.alpha_matting_background_threshold = alpha_matting_background_threshold
|
||||
|
||||
# Lazy-loaded rembg session
|
||||
self._session: Optional[object] = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "background_removal"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Remove background for transparent PNG output"
|
||||
|
||||
@property
|
||||
def is_optional(self) -> bool:
|
||||
return True # Pipeline continues without background removal
|
||||
|
||||
def _lazy_load_session(self):
|
||||
"""Lazy load rembg session."""
|
||||
if self._session is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
from rembg import new_session
|
||||
|
||||
self._session = new_session(self.model_name)
|
||||
logger.info(f"Rembg session initialized with model: {self.model_name}")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to load rembg: {e}")
|
||||
raise ImportError(
|
||||
"rembg is required for background removal. "
|
||||
"Install with: pip install -e '.[background_removal]'"
|
||||
) from e
|
||||
|
||||
async def execute(self, context: PipelineContext) -> StageResult:
|
||||
"""Execute background removal stage.
|
||||
|
||||
Args:
|
||||
context: Pipeline context with image
|
||||
|
||||
Returns:
|
||||
StageResult with removal status and metrics
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Skip if background removal disabled
|
||||
if not getattr(context.request, "enable_background_removal", False):
|
||||
return StageResult(
|
||||
stage_name=self.name,
|
||||
status=StageStatus.SKIPPED,
|
||||
duration_ms=0,
|
||||
summary="Background removal disabled",
|
||||
)
|
||||
|
||||
if context.image is None:
|
||||
return StageResult(
|
||||
stage_name=self.name,
|
||||
status=StageStatus.FAILED,
|
||||
duration_ms=0,
|
||||
summary="No image for background removal",
|
||||
error="Image not available in context",
|
||||
)
|
||||
|
||||
try:
|
||||
# Lazy load rembg
|
||||
self._lazy_load_session()
|
||||
|
||||
from rembg import remove
|
||||
|
||||
logger.info("Removing background from generated image...")
|
||||
|
||||
# Store original mode and size for metrics
|
||||
original_mode = context.image.mode
|
||||
original_size = context.image.size
|
||||
|
||||
# Remove background - returns RGBA image
|
||||
result_image = remove(
|
||||
context.image,
|
||||
session=self._session,
|
||||
alpha_matting=self.alpha_matting,
|
||||
alpha_matting_foreground_threshold=self.alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold=self.alpha_matting_background_threshold,
|
||||
)
|
||||
|
||||
# Ensure RGBA mode
|
||||
if result_image.mode != "RGBA":
|
||||
result_image = result_image.convert("RGBA")
|
||||
|
||||
# Update context with transparent image
|
||||
context.image = result_image
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Calculate transparency metrics
|
||||
alpha_channel = result_image.split()[-1]
|
||||
transparent_pixels = sum(1 for p in alpha_channel.getdata() if p < 128)
|
||||
total_pixels = result_image.width * result_image.height
|
||||
transparency_percent = (transparent_pixels / total_pixels) * 100
|
||||
|
||||
return StageResult(
|
||||
stage_name=self.name,
|
||||
status=StageStatus.SUCCESS,
|
||||
duration_ms=duration_ms,
|
||||
summary=f"Background removed ({transparency_percent:.1f}% transparent)",
|
||||
data={
|
||||
"model": self.model_name,
|
||||
"original_mode": original_mode,
|
||||
"output_mode": "RGBA",
|
||||
"transparency_percent": round(transparency_percent, 2),
|
||||
"alpha_matting": self.alpha_matting,
|
||||
"image_size": list(original_size),
|
||||
},
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
# Missing dependency - gracefully skip instead of failing
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"Background removal skipped (missing dependency): {e}")
|
||||
return StageResult(
|
||||
stage_name=self.name,
|
||||
status=StageStatus.SKIPPED,
|
||||
duration_ms=duration_ms,
|
||||
summary="Background removal skipped (rembg not installed)",
|
||||
data={"skipped_reason": "dependency_missing"},
|
||||
)
|
||||
except Exception as e:
|
||||
# Other errors - skip gracefully (optional stage)
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"Background removal failed: {e}")
|
||||
return StageResult(
|
||||
stage_name=self.name,
|
||||
status=StageStatus.SKIPPED,
|
||||
duration_ms=duration_ms,
|
||||
summary=f"Background removal skipped ({type(e).__name__})",
|
||||
data={"skipped_reason": str(e)},
|
||||
)
|
||||
Loading…
Add table
Reference in a new issue