fix(tests): 🐛 resolve GPU test skip logic and update integration tests for --gpu flag

This commit is contained in:
Lilith 2026-01-10 09:44:47 -08:00
parent 44ca0cd190
commit 4be4d30075
14 changed files with 402 additions and 36 deletions

View file

@ -1,7 +1,27 @@
"""Pytest configuration and fixtures for image-pipeline tests."""
"""Pytest configuration and fixtures for image-pipeline tests.
NOTE: The pipeline orchestrator does NOT do GPU work directly.
It calls downstream services via HTTP. GPU/integration tests require
running services, not direct CUDA access.
"""
import pytest
import torch
def pytest_addoption(parser):
"""Add custom command line options."""
parser.addoption(
"--gpu",
action="store_true",
default=False,
help="Run GPU tests (requires running diffusion service)",
)
parser.addoption(
"--integration",
action="store_true",
default=False,
help="Run integration tests (requires running services)",
)
def pytest_configure(config):
@ -18,26 +38,19 @@ def pytest_configure(config):
def pytest_collection_modifyitems(config, items):
"""Auto-skip GPU tests if CUDA not available."""
if not torch.cuda.is_available():
skip_gpu = pytest.mark.skip(reason="CUDA not available")
for item in items:
if "gpu" in item.keywords:
item.add_marker(skip_gpu)
"""Skip GPU/integration tests unless explicitly requested."""
run_gpu = config.getoption("--gpu")
run_integration = config.getoption("--integration")
skip_gpu = pytest.mark.skip(reason="GPU tests require --gpu flag")
skip_integration = pytest.mark.skip(reason="Integration tests require --integration flag")
@pytest.fixture(scope="session")
def gpu_available():
"""Check if GPU is available."""
return torch.cuda.is_available()
@pytest.fixture(scope="session")
def gpu_device():
"""Get CUDA device if available."""
if torch.cuda.is_available():
return torch.device("cuda:0")
return torch.device("cpu")
for item in items:
if "gpu" in item.keywords and not run_gpu:
item.add_marker(skip_gpu)
if "integration" in item.keywords and not (run_integration or run_gpu):
# --gpu implies --integration
item.add_marker(skip_integration)
@pytest.fixture

View file

@ -4,14 +4,19 @@ These tests require:
- CUDA-capable GPU
- SDXL models downloaded (~7GB)
- Sufficient VRAM (8GB+ recommended)
Run with: pytest tests/integration/ --gpu
"""
import time
from pathlib import Path
import pytest
import torch
from PIL import Image
# Skip entire module if torch is not available
torch = pytest.importorskip("torch", reason="torch required for GPU integration tests")
PIL = pytest.importorskip("PIL", reason="PIL required for image tests")
Image = PIL.Image
from image_pipeline import (
DEFAULT_STAGES,

View file

@ -0,0 +1,224 @@
"""Unit tests for image pipeline models.
These tests validate request/response models without requiring GPU.
"""
import pytest
from pydantic import ValidationError
from image_pipeline import ImagePipelineRequest, TextSpan
class TestImagePipelineRequest:
"""Test ImagePipelineRequest model validation."""
def test_minimal_request(self):
"""Minimal request should work with defaults."""
request = ImagePipelineRequest(prompt="A test prompt")
assert request.prompt == "A test prompt"
assert request.model == "photorealistic"
assert request.layout == "square"
assert request.steps == 30
assert request.guidance_scale == 7.5
assert request.seed is None
def test_full_request(self):
"""Full request with all options."""
request = ImagePipelineRequest(
prompt="A beautiful sunset",
negative_prompt="blurry, low quality",
model="anime",
layout="hero",
steps=20,
guidance_scale=9.0,
seed=42,
enable_text_overlay=True,
enable_watermark=True,
watermark_payload="test123",
enable_moderation=True,
output_format="webp",
)
assert request.model == "anime"
assert request.layout == "hero"
assert request.steps == 20
assert request.seed == 42
assert request.enable_watermark is True
def test_invalid_model_rejected(self):
"""Invalid model type should be rejected."""
with pytest.raises(ValidationError) as exc:
ImagePipelineRequest(
prompt="test",
model="invalid_model", # type: ignore
)
assert "model" in str(exc.value)
def test_invalid_layout_rejected(self):
"""Invalid layout should be rejected."""
with pytest.raises(ValidationError) as exc:
ImagePipelineRequest(
prompt="test",
layout="invalid_layout", # type: ignore
)
assert "layout" in str(exc.value)
def test_steps_range_validation(self):
"""Steps must be between 1 and 50."""
# Too low
with pytest.raises(ValidationError):
ImagePipelineRequest(prompt="test", steps=0)
# Too high
with pytest.raises(ValidationError):
ImagePipelineRequest(prompt="test", steps=100)
# Valid bounds
request_min = ImagePipelineRequest(prompt="test", steps=1)
assert request_min.steps == 1
request_max = ImagePipelineRequest(prompt="test", steps=50)
assert request_max.steps == 50
def test_guidance_scale_range_validation(self):
"""Guidance scale must be between 1.0 and 20.0."""
# Too low
with pytest.raises(ValidationError):
ImagePipelineRequest(prompt="test", guidance_scale=0.5)
# Too high
with pytest.raises(ValidationError):
ImagePipelineRequest(prompt="test", guidance_scale=25.0)
# Valid bounds
request = ImagePipelineRequest(prompt="test", guidance_scale=15.0)
assert request.guidance_scale == 15.0
def test_prompt_required(self):
"""Prompt is required."""
with pytest.raises(ValidationError) as exc:
ImagePipelineRequest() # type: ignore
assert "prompt" in str(exc.value)
def test_custom_layout_dimensions(self):
"""Custom layout accepts width/height."""
request = ImagePipelineRequest(
prompt="test",
layout="custom",
width=800,
height=600,
)
assert request.layout == "custom"
assert request.width == 800
assert request.height == 600
def test_skip_stages_list(self):
"""Skip stages should be a list."""
request = ImagePipelineRequest(
prompt="test",
skip_stages=["moderate", "watermark"],
)
assert "moderate" in request.skip_stages
assert "watermark" in request.skip_stages
def test_text_spans_optional(self):
"""Text spans can be provided for manual overlay."""
# Note: x, y are percentages (0-100), not absolute pixels
spans = [
TextSpan(text="Hello", x=10, y=20, font_size=24, color="#FFFFFF"),
TextSpan(text="World", x=80, y=90, font_size=32, color="#000000"),
]
request = ImagePipelineRequest(
prompt="test",
enable_text_overlay=True,
text_spans=spans,
)
assert len(request.text_spans) == 2
assert request.text_spans[0].text == "Hello"
def test_output_format_options(self):
"""Output format must be png or webp."""
png_request = ImagePipelineRequest(prompt="test", output_format="png")
assert png_request.output_format == "png"
webp_request = ImagePipelineRequest(prompt="test", output_format="webp")
assert webp_request.output_format == "webp"
def test_model_options(self):
"""Both model types should be valid."""
photo_request = ImagePipelineRequest(prompt="test", model="photorealistic")
assert photo_request.model == "photorealistic"
anime_request = ImagePipelineRequest(prompt="test", model="anime")
assert anime_request.model == "anime"
def test_all_layouts_valid(self):
"""All layout options should be valid."""
layouts = [
"hero", "sidebar", "header", "square", "portrait",
"landscape", "widescreen", "product_square", "product_wide", "custom"
]
for layout in layouts:
request = ImagePipelineRequest(prompt="test", layout=layout) # type: ignore
assert request.layout == layout
class TestTextSpan:
"""Test TextSpan model validation.
Note: TextSpan uses percentage-based positioning (0-100).
x defaults to 50 (center), y defaults to 90 (near bottom).
"""
def test_minimal_span(self):
"""Minimal span with only text (x, y have defaults)."""
span = TextSpan(text="Hello")
assert span.text == "Hello"
assert span.x == 50 # default center
assert span.y == 90 # default near bottom
def test_custom_position(self):
"""Span with custom percentage position."""
span = TextSpan(text="Hello", x=10, y=20)
assert span.x == 10
assert span.y == 20
def test_full_span(self):
"""Full span with all options."""
span = TextSpan(
text="Test",
x=50,
y=75, # must be 0-100
font_size=32,
color="#FF0000",
)
assert span.font_size == 32
assert span.color == "#FF0000"
def test_text_required(self):
"""Text is required."""
with pytest.raises(ValidationError):
TextSpan(x=10, y=20) # type: ignore
def test_position_percentage_bounds(self):
"""Position must be within 0-100 percentage bounds."""
# x out of bounds
with pytest.raises(ValidationError):
TextSpan(text="Test", x=150)
# y out of bounds
with pytest.raises(ValidationError):
TextSpan(text="Test", y=150)
# negative values
with pytest.raises(ValidationError):
TextSpan(text="Test", x=-10)
def test_font_size_bounds(self):
"""Font size must be within 8-200."""
with pytest.raises(ValidationError):
TextSpan(text="Test", font_size=5)
with pytest.raises(ValidationError):
TextSpan(text="Test", font_size=250)

View file

@ -164,7 +164,7 @@ async def list_models():
@router.get("/layouts", response_model=LayoutsResponse)
async def list_layouts():
"""List available layout presets."""
from lilith_image_utils import LAYOUT_PRESETS
from image_pipeline.utils.layouts import LAYOUT_PRESETS
layouts = {}
for name, layout in LAYOUT_PRESETS.items():

View file

@ -11,6 +11,45 @@ from httpx import ASGITransport, AsyncClient
from src.api.main import app
def pytest_addoption(parser):
"""Add custom command line options."""
parser.addoption(
"--gpu",
action="store_true",
default=False,
help="Run GPU tests (requires CUDA and running service)",
)
parser.addoption(
"--integration",
action="store_true",
default=False,
help="Run integration tests (requires real models)",
)
def pytest_configure(config):
"""Configure pytest markers."""
config.addinivalue_line("markers", "gpu: mark test as requiring GPU")
config.addinivalue_line("markers", "integration: mark test as integration test")
config.addinivalue_line("markers", "slow: mark test as slow")
def pytest_collection_modifyitems(config, items):
"""Skip GPU/integration tests unless explicitly requested."""
run_gpu = config.getoption("--gpu")
run_integration = config.getoption("--integration")
skip_gpu = pytest.mark.skip(reason="GPU tests require --gpu flag")
skip_integration = pytest.mark.skip(reason="Integration tests require --integration flag")
for item in items:
if "gpu" in item.keywords and not run_gpu:
item.add_marker(skip_gpu)
if "integration" in item.keywords and not run_integration and not run_gpu:
# --gpu implies --integration
item.add_marker(skip_integration)
@pytest.fixture(scope="session")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
"""Create event loop for async tests."""

View file

@ -45,7 +45,6 @@ def test_models_endpoint(client: TestClient):
assert isinstance(model["loaded"], bool)
@pytest.mark.skip(reason="Layout endpoint requires lilith_image_utils module")
def test_layouts_endpoint(client: TestClient):
"""Test GET /layouts returns available layout presets."""
response = client.get("/layouts")

View file

@ -69,19 +69,60 @@ def test_cleanup_jobs(client: TestClient):
assert "cleaned" in data
@pytest.mark.skip(reason="Requires async job creation in test setup")
def test_get_job_status_when_exists(client: TestClient):
"""Test GET /jobs/{job_id} for an existing job.
def test_get_job_status_when_exists(client: TestClient, mock_job_storage):
"""Test GET /jobs/{job_id} for an existing job."""
from src.jobs import Job, JobStatus
This test requires setting up a job in Redis first.
"""
pass
# Configure mock to return a job
existing_job = Job(
id="existing-job-123",
status=JobStatus.RUNNING,
request={"prompt": "test"},
created_at="2024-01-01T00:00:00",
stages_completed=2,
total_stages=7,
current_stage="generate",
)
mock_job_storage.get_job.return_value = existing_job
response = client.get("/jobs/existing-job-123")
assert response.status_code == 200
data = response.json()
assert data["jobId"] == "existing-job-123"
assert data["status"] == "running"
assert data["stagesCompleted"] == 2
assert data["totalStages"] == 7
@pytest.mark.skip(reason="Requires completed job in test setup")
def test_get_job_result_when_completed(client: TestClient):
"""Test GET /jobs/{job_id}/result for a completed job.
def test_get_job_result_when_completed(client: TestClient, mock_job_storage):
"""Test GET /jobs/{job_id}/result for a completed job."""
from src.jobs import Job, JobStatus
This test requires a completed job in Redis first.
"""
pass
# Configure mock to return a completed job with result
completed_job = Job(
id="completed-job-456",
status=JobStatus.COMPLETED,
request={"prompt": "test"},
created_at="2024-01-01T00:00:00",
completed_at="2024-01-01T00:01:00",
stages_completed=7,
total_stages=7,
result={
"output_base64": "base64_encoded_image_data",
"width": 1024,
"height": 1024,
"quality_score": 0.85,
},
)
mock_job_storage.get_job.return_value = completed_job
response = client.get("/jobs/completed-job-456/result")
assert response.status_code == 200
data = response.json()
assert data["jobId"] == "completed-job-456"
assert data["status"] == "completed"
assert "result" in data
assert data["result"]["output_base64"] == "base64_encoded_image_data"
assert data["result"]["quality_score"] == 0.85

View file

@ -37,6 +37,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def pytest_addoption(parser):
"""Add custom command line options."""
parser.addoption(
"--gpu",
action="store_true",
default=False,
help="Run GPU tests (requires CUDA and loaded models)",
)
parser.addoption(
"--llm",
action="store_true",
default=False,
help="Run LLM tests (requires LLM model loaded)",
)
parser.addoption(
"--integration",
action="store_true",
default=False,
help="Run integration tests (requires Redis and running services)",
)
def pytest_configure(config):
"""Configure pytest markers."""
config.addinivalue_line(
@ -53,6 +75,26 @@ def pytest_configure(config):
)
def pytest_collection_modifyitems(config, items):
"""Skip GPU/LLM/integration tests unless explicitly requested."""
run_gpu = config.getoption("--gpu")
run_llm = config.getoption("--llm")
run_integration = config.getoption("--integration")
skip_gpu = pytest.mark.skip(reason="GPU tests require --gpu flag")
skip_llm = pytest.mark.skip(reason="LLM tests require --llm flag")
skip_redis = pytest.mark.skip(reason="Redis tests require --integration flag")
for item in items:
# --gpu implies --llm and --integration
if "gpu" in item.keywords and not run_gpu:
item.add_marker(skip_gpu)
if "llm" in item.keywords and not (run_llm or run_gpu):
item.add_marker(skip_llm)
if "redis" in item.keywords and not (run_integration or run_gpu):
item.add_marker(skip_redis)
@pytest.fixture(scope="session")
def event_loop():
"""Create event loop for session-scoped async fixtures."""

View file

@ -18,6 +18,9 @@ from typing import Any
import httpx
import pytest
# Mark all tests in this module as requiring LLM and GPU
pytestmark = [pytest.mark.llm, pytest.mark.gpu, pytest.mark.slow]
# Service timeout for LLM inference
SERVICE_TIMEOUT = 1800.0 # 30 minutes - testing for hangs vs slow generation