"""Tests for the chat SSE stream endpoint.""" from __future__ import annotations import threading import time from pathlib import Path import pytest from fastapi.testclient import TestClient @pytest.fixture def client(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> TestClient: monkeypatch.setenv("XDG_DATA_HOME", str(tmp_path / "data")) monkeypatch.setenv("XDG_CONFIG_HOME", str(tmp_path / "config")) from claire.web.app import create_app return TestClient(create_app()) # --------------------------------------------------------------------------- # Content-type + auth/scope validation # --------------------------------------------------------------------------- def test_stream_content_type_is_event_stream(client: TestClient) -> None: with client.stream( "GET", "/chat/stream", params={ "scope": "orchestrator", "after_rowid": 0, "poll_interval": 0.05, "max_duration": 0.2, }, ) as resp: assert resp.status_code == 200 assert resp.headers["content-type"].startswith("text/event-stream") # Drain — generator exits when max_duration elapses. for _ in resp.iter_text(): pass def test_stream_bad_scope_returns_400(client: TestClient) -> None: r = client.get( "/chat/stream", params={"scope": "garbage", "after_rowid": 0}, ) assert r.status_code == 400 def test_stream_invalid_scope_ref_returns_400(client: TestClient) -> None: # orchestrator with a scope_ref is invalid input. r = client.get( "/chat/stream", params={ "scope": "orchestrator", "scope_ref": "should-not-be-here", "after_rowid": 0, }, ) assert r.status_code == 400 def test_stream_unknown_project_returns_404(client: TestClient) -> None: r = client.get( "/chat/stream", params={"scope": "project", "scope_ref": "nope", "after_rowid": 0}, ) assert r.status_code == 404 # --------------------------------------------------------------------------- # Streaming behaviour # --------------------------------------------------------------------------- def _collect_until(resp, marker: str, timeout: float = 3.0) -> str: """Read raw text from an SSE response until `marker` shows up or stream ends.""" deadline = time.time() + timeout buf = "" for chunk in resp.iter_text(): buf += chunk if marker in buf: return buf if time.time() > deadline: break return buf # stream ended (max_duration elapsed) — return what we got def test_stream_yields_existing_backlog(client: TestClient) -> None: """Initial flush: rowids already past after_rowid arrive immediately. Post-R6 the SSE payload is `{"new_rowids": [...]}` JSON, not HTML. The React client uses the event as a wake-up signal and refetches via `GET /api/v1/chat`. """ r = client.post( "/api/v1/chat", json={"scope": "orchestrator", "scope_ref": None, "body": "hello-backlog"}, ) user_rowid = r.json()["user_message"]["rowid"] marker = f"{user_rowid}" with client.stream( "GET", "/chat/stream", params={ "scope": "orchestrator", "after_rowid": 0, "poll_interval": 0.05, "max_duration": 1.0, }, ) as resp: assert resp.status_code == 200 text = _collect_until(resp, marker, timeout=2.0) assert "event: chat" in text assert "new_rowids" in text assert marker in text def test_stream_pushes_new_message(client: TestClient) -> None: """A message posted AFTER the connection opens is pushed within ~2s.""" # Snapshot baseline rowid. r0 = client.post( "/api/v1/chat", json={"scope": "orchestrator", "scope_ref": None, "body": "baseline"}, ) baseline = r0.json()["user_message"]["rowid"] target_rowid: list[int] = [] def post_later() -> None: time.sleep(0.3) r = client.post( "/api/v1/chat", json={ "scope": "orchestrator", "scope_ref": None, "body": "live-update-marker", }, ) target_rowid.append(r.json()["user_message"]["rowid"]) poster = threading.Thread(target=post_later, daemon=True) poster.start() with client.stream( "GET", "/chat/stream", params={ "scope": "orchestrator", "after_rowid": baseline, "poll_interval": 0.1, "max_duration": 3.0, }, ) as resp: assert resp.status_code == 200 poster.join(timeout=2.0) assert target_rowid, "background poster failed to deliver" text = _collect_until(resp, str(target_rowid[0]), timeout=4.0) assert "event: chat" in text assert "new_rowids" in text assert str(target_rowid[0]) in text