fix(tests): 🐛 resolve GPU test skip logic and update integration tests for --gpu flag
This commit is contained in:
parent
44ca0cd190
commit
4be4d30075
14 changed files with 402 additions and 36 deletions
Binary file not shown.
Binary file not shown.
|
|
@ -1,7 +1,27 @@
|
|||
"""Pytest configuration and fixtures for image-pipeline tests."""
|
||||
"""Pytest configuration and fixtures for image-pipeline tests.
|
||||
|
||||
NOTE: The pipeline orchestrator does NOT do GPU work directly.
|
||||
It calls downstream services via HTTP. GPU/integration tests require
|
||||
running services, not direct CUDA access.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command line options."""
|
||||
parser.addoption(
|
||||
"--gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run GPU tests (requires running diffusion service)",
|
||||
)
|
||||
parser.addoption(
|
||||
"--integration",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run integration tests (requires running services)",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
|
|
@ -18,26 +38,19 @@ def pytest_configure(config):
|
|||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""Auto-skip GPU tests if CUDA not available."""
|
||||
if not torch.cuda.is_available():
|
||||
skip_gpu = pytest.mark.skip(reason="CUDA not available")
|
||||
for item in items:
|
||||
if "gpu" in item.keywords:
|
||||
item.add_marker(skip_gpu)
|
||||
"""Skip GPU/integration tests unless explicitly requested."""
|
||||
run_gpu = config.getoption("--gpu")
|
||||
run_integration = config.getoption("--integration")
|
||||
|
||||
skip_gpu = pytest.mark.skip(reason="GPU tests require --gpu flag")
|
||||
skip_integration = pytest.mark.skip(reason="Integration tests require --integration flag")
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def gpu_available():
|
||||
"""Check if GPU is available."""
|
||||
return torch.cuda.is_available()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def gpu_device():
|
||||
"""Get CUDA device if available."""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda:0")
|
||||
return torch.device("cpu")
|
||||
for item in items:
|
||||
if "gpu" in item.keywords and not run_gpu:
|
||||
item.add_marker(skip_gpu)
|
||||
if "integration" in item.keywords and not (run_integration or run_gpu):
|
||||
# --gpu implies --integration
|
||||
item.add_marker(skip_integration)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -4,14 +4,19 @@ These tests require:
|
|||
- CUDA-capable GPU
|
||||
- SDXL models downloaded (~7GB)
|
||||
- Sufficient VRAM (8GB+ recommended)
|
||||
|
||||
Run with: pytest tests/integration/ --gpu
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
# Skip entire module if torch is not available
|
||||
torch = pytest.importorskip("torch", reason="torch required for GPU integration tests")
|
||||
PIL = pytest.importorskip("PIL", reason="PIL required for image tests")
|
||||
Image = PIL.Image
|
||||
|
||||
from image_pipeline import (
|
||||
DEFAULT_STAGES,
|
||||
|
|
|
|||
224
orchestrators/imajin-pipeline/tests/test_models.py
Normal file
224
orchestrators/imajin-pipeline/tests/test_models.py
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
"""Unit tests for image pipeline models.
|
||||
|
||||
These tests validate request/response models without requiring GPU.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from image_pipeline import ImagePipelineRequest, TextSpan
|
||||
|
||||
|
||||
class TestImagePipelineRequest:
|
||||
"""Test ImagePipelineRequest model validation."""
|
||||
|
||||
def test_minimal_request(self):
|
||||
"""Minimal request should work with defaults."""
|
||||
request = ImagePipelineRequest(prompt="A test prompt")
|
||||
|
||||
assert request.prompt == "A test prompt"
|
||||
assert request.model == "photorealistic"
|
||||
assert request.layout == "square"
|
||||
assert request.steps == 30
|
||||
assert request.guidance_scale == 7.5
|
||||
assert request.seed is None
|
||||
|
||||
def test_full_request(self):
|
||||
"""Full request with all options."""
|
||||
request = ImagePipelineRequest(
|
||||
prompt="A beautiful sunset",
|
||||
negative_prompt="blurry, low quality",
|
||||
model="anime",
|
||||
layout="hero",
|
||||
steps=20,
|
||||
guidance_scale=9.0,
|
||||
seed=42,
|
||||
enable_text_overlay=True,
|
||||
enable_watermark=True,
|
||||
watermark_payload="test123",
|
||||
enable_moderation=True,
|
||||
output_format="webp",
|
||||
)
|
||||
|
||||
assert request.model == "anime"
|
||||
assert request.layout == "hero"
|
||||
assert request.steps == 20
|
||||
assert request.seed == 42
|
||||
assert request.enable_watermark is True
|
||||
|
||||
def test_invalid_model_rejected(self):
|
||||
"""Invalid model type should be rejected."""
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
ImagePipelineRequest(
|
||||
prompt="test",
|
||||
model="invalid_model", # type: ignore
|
||||
)
|
||||
assert "model" in str(exc.value)
|
||||
|
||||
def test_invalid_layout_rejected(self):
|
||||
"""Invalid layout should be rejected."""
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
ImagePipelineRequest(
|
||||
prompt="test",
|
||||
layout="invalid_layout", # type: ignore
|
||||
)
|
||||
assert "layout" in str(exc.value)
|
||||
|
||||
def test_steps_range_validation(self):
|
||||
"""Steps must be between 1 and 50."""
|
||||
# Too low
|
||||
with pytest.raises(ValidationError):
|
||||
ImagePipelineRequest(prompt="test", steps=0)
|
||||
|
||||
# Too high
|
||||
with pytest.raises(ValidationError):
|
||||
ImagePipelineRequest(prompt="test", steps=100)
|
||||
|
||||
# Valid bounds
|
||||
request_min = ImagePipelineRequest(prompt="test", steps=1)
|
||||
assert request_min.steps == 1
|
||||
|
||||
request_max = ImagePipelineRequest(prompt="test", steps=50)
|
||||
assert request_max.steps == 50
|
||||
|
||||
def test_guidance_scale_range_validation(self):
|
||||
"""Guidance scale must be between 1.0 and 20.0."""
|
||||
# Too low
|
||||
with pytest.raises(ValidationError):
|
||||
ImagePipelineRequest(prompt="test", guidance_scale=0.5)
|
||||
|
||||
# Too high
|
||||
with pytest.raises(ValidationError):
|
||||
ImagePipelineRequest(prompt="test", guidance_scale=25.0)
|
||||
|
||||
# Valid bounds
|
||||
request = ImagePipelineRequest(prompt="test", guidance_scale=15.0)
|
||||
assert request.guidance_scale == 15.0
|
||||
|
||||
def test_prompt_required(self):
|
||||
"""Prompt is required."""
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
ImagePipelineRequest() # type: ignore
|
||||
assert "prompt" in str(exc.value)
|
||||
|
||||
def test_custom_layout_dimensions(self):
|
||||
"""Custom layout accepts width/height."""
|
||||
request = ImagePipelineRequest(
|
||||
prompt="test",
|
||||
layout="custom",
|
||||
width=800,
|
||||
height=600,
|
||||
)
|
||||
assert request.layout == "custom"
|
||||
assert request.width == 800
|
||||
assert request.height == 600
|
||||
|
||||
def test_skip_stages_list(self):
|
||||
"""Skip stages should be a list."""
|
||||
request = ImagePipelineRequest(
|
||||
prompt="test",
|
||||
skip_stages=["moderate", "watermark"],
|
||||
)
|
||||
assert "moderate" in request.skip_stages
|
||||
assert "watermark" in request.skip_stages
|
||||
|
||||
def test_text_spans_optional(self):
|
||||
"""Text spans can be provided for manual overlay."""
|
||||
# Note: x, y are percentages (0-100), not absolute pixels
|
||||
spans = [
|
||||
TextSpan(text="Hello", x=10, y=20, font_size=24, color="#FFFFFF"),
|
||||
TextSpan(text="World", x=80, y=90, font_size=32, color="#000000"),
|
||||
]
|
||||
request = ImagePipelineRequest(
|
||||
prompt="test",
|
||||
enable_text_overlay=True,
|
||||
text_spans=spans,
|
||||
)
|
||||
assert len(request.text_spans) == 2
|
||||
assert request.text_spans[0].text == "Hello"
|
||||
|
||||
def test_output_format_options(self):
|
||||
"""Output format must be png or webp."""
|
||||
png_request = ImagePipelineRequest(prompt="test", output_format="png")
|
||||
assert png_request.output_format == "png"
|
||||
|
||||
webp_request = ImagePipelineRequest(prompt="test", output_format="webp")
|
||||
assert webp_request.output_format == "webp"
|
||||
|
||||
def test_model_options(self):
|
||||
"""Both model types should be valid."""
|
||||
photo_request = ImagePipelineRequest(prompt="test", model="photorealistic")
|
||||
assert photo_request.model == "photorealistic"
|
||||
|
||||
anime_request = ImagePipelineRequest(prompt="test", model="anime")
|
||||
assert anime_request.model == "anime"
|
||||
|
||||
def test_all_layouts_valid(self):
|
||||
"""All layout options should be valid."""
|
||||
layouts = [
|
||||
"hero", "sidebar", "header", "square", "portrait",
|
||||
"landscape", "widescreen", "product_square", "product_wide", "custom"
|
||||
]
|
||||
for layout in layouts:
|
||||
request = ImagePipelineRequest(prompt="test", layout=layout) # type: ignore
|
||||
assert request.layout == layout
|
||||
|
||||
|
||||
class TestTextSpan:
|
||||
"""Test TextSpan model validation.
|
||||
|
||||
Note: TextSpan uses percentage-based positioning (0-100).
|
||||
x defaults to 50 (center), y defaults to 90 (near bottom).
|
||||
"""
|
||||
|
||||
def test_minimal_span(self):
|
||||
"""Minimal span with only text (x, y have defaults)."""
|
||||
span = TextSpan(text="Hello")
|
||||
assert span.text == "Hello"
|
||||
assert span.x == 50 # default center
|
||||
assert span.y == 90 # default near bottom
|
||||
|
||||
def test_custom_position(self):
|
||||
"""Span with custom percentage position."""
|
||||
span = TextSpan(text="Hello", x=10, y=20)
|
||||
assert span.x == 10
|
||||
assert span.y == 20
|
||||
|
||||
def test_full_span(self):
|
||||
"""Full span with all options."""
|
||||
span = TextSpan(
|
||||
text="Test",
|
||||
x=50,
|
||||
y=75, # must be 0-100
|
||||
font_size=32,
|
||||
color="#FF0000",
|
||||
)
|
||||
assert span.font_size == 32
|
||||
assert span.color == "#FF0000"
|
||||
|
||||
def test_text_required(self):
|
||||
"""Text is required."""
|
||||
with pytest.raises(ValidationError):
|
||||
TextSpan(x=10, y=20) # type: ignore
|
||||
|
||||
def test_position_percentage_bounds(self):
|
||||
"""Position must be within 0-100 percentage bounds."""
|
||||
# x out of bounds
|
||||
with pytest.raises(ValidationError):
|
||||
TextSpan(text="Test", x=150)
|
||||
|
||||
# y out of bounds
|
||||
with pytest.raises(ValidationError):
|
||||
TextSpan(text="Test", y=150)
|
||||
|
||||
# negative values
|
||||
with pytest.raises(ValidationError):
|
||||
TextSpan(text="Test", x=-10)
|
||||
|
||||
def test_font_size_bounds(self):
|
||||
"""Font size must be within 8-200."""
|
||||
with pytest.raises(ValidationError):
|
||||
TextSpan(text="Test", font_size=5)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
TextSpan(text="Test", font_size=250)
|
||||
|
|
@ -164,7 +164,7 @@ async def list_models():
|
|||
@router.get("/layouts", response_model=LayoutsResponse)
|
||||
async def list_layouts():
|
||||
"""List available layout presets."""
|
||||
from lilith_image_utils import LAYOUT_PRESETS
|
||||
from image_pipeline.utils.layouts import LAYOUT_PRESETS
|
||||
|
||||
layouts = {}
|
||||
for name, layout in LAYOUT_PRESETS.items():
|
||||
|
|
|
|||
|
|
@ -11,6 +11,45 @@ from httpx import ASGITransport, AsyncClient
|
|||
from src.api.main import app
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command line options."""
|
||||
parser.addoption(
|
||||
"--gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run GPU tests (requires CUDA and running service)",
|
||||
)
|
||||
parser.addoption(
|
||||
"--integration",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run integration tests (requires real models)",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest markers."""
|
||||
config.addinivalue_line("markers", "gpu: mark test as requiring GPU")
|
||||
config.addinivalue_line("markers", "integration: mark test as integration test")
|
||||
config.addinivalue_line("markers", "slow: mark test as slow")
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""Skip GPU/integration tests unless explicitly requested."""
|
||||
run_gpu = config.getoption("--gpu")
|
||||
run_integration = config.getoption("--integration")
|
||||
|
||||
skip_gpu = pytest.mark.skip(reason="GPU tests require --gpu flag")
|
||||
skip_integration = pytest.mark.skip(reason="Integration tests require --integration flag")
|
||||
|
||||
for item in items:
|
||||
if "gpu" in item.keywords and not run_gpu:
|
||||
item.add_marker(skip_gpu)
|
||||
if "integration" in item.keywords and not run_integration and not run_gpu:
|
||||
# --gpu implies --integration
|
||||
item.add_marker(skip_integration)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
"""Create event loop for async tests."""
|
||||
|
|
|
|||
|
|
@ -45,7 +45,6 @@ def test_models_endpoint(client: TestClient):
|
|||
assert isinstance(model["loaded"], bool)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Layout endpoint requires lilith_image_utils module")
|
||||
def test_layouts_endpoint(client: TestClient):
|
||||
"""Test GET /layouts returns available layout presets."""
|
||||
response = client.get("/layouts")
|
||||
|
|
|
|||
|
|
@ -69,19 +69,60 @@ def test_cleanup_jobs(client: TestClient):
|
|||
assert "cleaned" in data
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires async job creation in test setup")
|
||||
def test_get_job_status_when_exists(client: TestClient):
|
||||
"""Test GET /jobs/{job_id} for an existing job.
|
||||
def test_get_job_status_when_exists(client: TestClient, mock_job_storage):
|
||||
"""Test GET /jobs/{job_id} for an existing job."""
|
||||
from src.jobs import Job, JobStatus
|
||||
|
||||
This test requires setting up a job in Redis first.
|
||||
"""
|
||||
pass
|
||||
# Configure mock to return a job
|
||||
existing_job = Job(
|
||||
id="existing-job-123",
|
||||
status=JobStatus.RUNNING,
|
||||
request={"prompt": "test"},
|
||||
created_at="2024-01-01T00:00:00",
|
||||
stages_completed=2,
|
||||
total_stages=7,
|
||||
current_stage="generate",
|
||||
)
|
||||
mock_job_storage.get_job.return_value = existing_job
|
||||
|
||||
response = client.get("/jobs/existing-job-123")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["jobId"] == "existing-job-123"
|
||||
assert data["status"] == "running"
|
||||
assert data["stagesCompleted"] == 2
|
||||
assert data["totalStages"] == 7
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires completed job in test setup")
|
||||
def test_get_job_result_when_completed(client: TestClient):
|
||||
"""Test GET /jobs/{job_id}/result for a completed job.
|
||||
def test_get_job_result_when_completed(client: TestClient, mock_job_storage):
|
||||
"""Test GET /jobs/{job_id}/result for a completed job."""
|
||||
from src.jobs import Job, JobStatus
|
||||
|
||||
This test requires a completed job in Redis first.
|
||||
"""
|
||||
pass
|
||||
# Configure mock to return a completed job with result
|
||||
completed_job = Job(
|
||||
id="completed-job-456",
|
||||
status=JobStatus.COMPLETED,
|
||||
request={"prompt": "test"},
|
||||
created_at="2024-01-01T00:00:00",
|
||||
completed_at="2024-01-01T00:01:00",
|
||||
stages_completed=7,
|
||||
total_stages=7,
|
||||
result={
|
||||
"output_base64": "base64_encoded_image_data",
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"quality_score": 0.85,
|
||||
},
|
||||
)
|
||||
mock_job_storage.get_job.return_value = completed_job
|
||||
|
||||
response = client.get("/jobs/completed-job-456/result")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["jobId"] == "completed-job-456"
|
||||
assert data["status"] == "completed"
|
||||
assert "result" in data
|
||||
assert data["result"]["output_base64"] == "base64_encoded_image_data"
|
||||
assert data["result"]["quality_score"] == 0.85
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
|
|
@ -37,6 +37,28 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command line options."""
|
||||
parser.addoption(
|
||||
"--gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run GPU tests (requires CUDA and loaded models)",
|
||||
)
|
||||
parser.addoption(
|
||||
"--llm",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run LLM tests (requires LLM model loaded)",
|
||||
)
|
||||
parser.addoption(
|
||||
"--integration",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run integration tests (requires Redis and running services)",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest markers."""
|
||||
config.addinivalue_line(
|
||||
|
|
@ -53,6 +75,26 @@ def pytest_configure(config):
|
|||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
"""Skip GPU/LLM/integration tests unless explicitly requested."""
|
||||
run_gpu = config.getoption("--gpu")
|
||||
run_llm = config.getoption("--llm")
|
||||
run_integration = config.getoption("--integration")
|
||||
|
||||
skip_gpu = pytest.mark.skip(reason="GPU tests require --gpu flag")
|
||||
skip_llm = pytest.mark.skip(reason="LLM tests require --llm flag")
|
||||
skip_redis = pytest.mark.skip(reason="Redis tests require --integration flag")
|
||||
|
||||
for item in items:
|
||||
# --gpu implies --llm and --integration
|
||||
if "gpu" in item.keywords and not run_gpu:
|
||||
item.add_marker(skip_gpu)
|
||||
if "llm" in item.keywords and not (run_llm or run_gpu):
|
||||
item.add_marker(skip_llm)
|
||||
if "redis" in item.keywords and not (run_integration or run_gpu):
|
||||
item.add_marker(skip_redis)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create event loop for session-scoped async fixtures."""
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ from typing import Any
|
|||
import httpx
|
||||
import pytest
|
||||
|
||||
# Mark all tests in this module as requiring LLM and GPU
|
||||
pytestmark = [pytest.mark.llm, pytest.mark.gpu, pytest.mark.slow]
|
||||
|
||||
# Service timeout for LLM inference
|
||||
SERVICE_TIMEOUT = 1800.0 # 30 minutes - testing for hangs vs slow generation
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue