claire/tests/test_chat_stream.py
autocommit 6d212b7dbe refactor(testing-test): ♻️ Update test imports to use claire instead of clare in package references
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-05-20 19:54:05 -07:00

157 lines
4.8 KiB
Python

"""Tests for the chat SSE stream endpoint."""
from __future__ import annotations

import re
import threading
import time
from pathlib import Path

import pytest
from fastapi.testclient import TestClient
@pytest.fixture
def client(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> TestClient:
    """Return a TestClient for a fresh app with XDG dirs redirected to tmp_path.

    XDG_DATA_HOME and XDG_CONFIG_HOME point at per-test subdirectories, so
    each test gets isolated storage. The app factory is imported only after
    the environment is patched — presumably so any XDG-derived paths resolved
    at import/construction time see the temporary directories (confirm
    against claire.web.app).
    """
    for env_var, subdir in (("XDG_DATA_HOME", "data"), ("XDG_CONFIG_HOME", "config")):
        monkeypatch.setenv(env_var, str(tmp_path / subdir))

    from claire.web.app import create_app

    app = create_app()
    return TestClient(app)
# ---------------------------------------------------------------------------
# Content-type + auth/scope validation
# ---------------------------------------------------------------------------
def test_stream_content_type_is_event_stream(client: TestClient) -> None:
    """The stream endpoint answers 200 with a `text/event-stream` body."""
    query = {
        "scope": "orchestrator",
        "after_rowid": 0,
        "poll_interval": 0.05,
        "max_duration": 0.2,
    }
    with client.stream("GET", "/chat/stream", params=query) as response:
        assert response.status_code == 200
        content_type = response.headers["content-type"]
        assert content_type.startswith("text/event-stream")
        # Consume the entire body; the server-side generator finishes on its
        # own once max_duration has elapsed.
        response.read()
def test_stream_bad_scope_returns_400(client: TestClient) -> None:
    """An unrecognised `scope` value is rejected with HTTP 400."""
    response = client.get(
        "/chat/stream", params=dict(scope="garbage", after_rowid=0)
    )
    assert response.status_code == 400
def test_stream_invalid_scope_ref_returns_400(client: TestClient) -> None:
    """The orchestrator scope takes no `scope_ref`; supplying one is a 400."""
    query = dict(
        scope="orchestrator",
        scope_ref="should-not-be-here",
        after_rowid=0,
    )
    status = client.get("/chat/stream", params=query).status_code
    assert status == 400
def test_stream_unknown_project_returns_404(client: TestClient) -> None:
    """Streaming a project scope whose ref is unknown yields HTTP 404."""
    query = {"scope": "project", "scope_ref": "nope", "after_rowid": 0}
    assert client.get("/chat/stream", params=query).status_code == 404
# ---------------------------------------------------------------------------
# Streaming behaviour
# ---------------------------------------------------------------------------
def _collect_until(resp, marker: str, timeout: float = 3.0) -> str:
    """Accumulate decoded text from *resp* until *marker* appears.

    ``resp`` is anything exposing ``iter_text()``, such as the response
    yielded by ``TestClient.stream``.

    Returns the concatenated text as soon as it contains ``marker`` (which
    may span chunk boundaries). If the stream ends first, or ``timeout``
    seconds have elapsed, returns whatever was read so far; callers make
    their own assertions on the result.

    The deadline is only checked after each chunk arrives, so a stream that
    sends nothing blocks until it ends on its own (for the chat stream, when
    the server's ``max_duration`` elapses).
    """
    # monotonic() is immune to wall-clock adjustments (NTP, DST, manual
    # changes), unlike time.time(), so elapsed-time checks stay meaningful.
    deadline = time.monotonic() + timeout
    buf = ""
    for chunk in resp.iter_text():
        buf += chunk
        if marker in buf:
            return buf
        if time.monotonic() > deadline:
            break
    return buf  # stream ended (max_duration elapsed) or timed out
def test_stream_yields_existing_backlog(client: TestClient) -> None:
    """Initial flush: rowids already past after_rowid arrive immediately.

    Post-R6 the SSE payload is `{"new_rowids": [...]}` JSON, not HTML.
    The React client uses the event as a wake-up signal and refetches via
    `GET /api/v1/chat`.
    """
    r = client.post(
        "/api/v1/chat",
        json={"scope": "orchestrator", "scope_ref": None, "body": "hello-backlog"},
    )
    user_rowid = r.json()["user_message"]["rowid"]
    marker = f"{user_rowid}"
    with client.stream(
        "GET",
        "/chat/stream",
        params={
            "scope": "orchestrator", "after_rowid": 0,
            "poll_interval": 0.05, "max_duration": 1.0,
        },
    ) as resp:
        assert resp.status_code == 200
        text = _collect_until(resp, marker, timeout=2.0)
    assert "event: chat" in text
    assert "new_rowids" in text
    # A bare substring check on a small rowid (e.g. "1") is satisfied by
    # almost any output, so require the rowid as a whole number inside a
    # `new_rowids` JSON list.
    rowid_re = re.escape(str(user_rowid))
    assert re.search(rf'"new_rowids"\s*:\s*\[[^\]]*\b{rowid_re}\b', text), text
def test_stream_pushes_new_message(client: TestClient) -> None:
    """A message posted AFTER the connection opens is pushed within ~2s.

    NOTE(review): the poster thread is started before the stream request and
    relies on its 0.3s sleep to land after the connection opens. Because the
    stream uses ``after_rowid=baseline``, a message that landed earlier would
    presumably still be delivered by the initial backlog flush, so this test
    does not strictly distinguish the two delivery paths.
    """
    # Snapshot baseline rowid; the stream below asks only for rowids after it.
    r0 = client.post(
        "/api/v1/chat",
        json={"scope": "orchestrator", "scope_ref": None, "body": "baseline"},
    )
    baseline = r0.json()["user_message"]["rowid"]

    target_rowid: list[int] = []
    # Exceptions raised in the thread would otherwise vanish; keep them so the
    # failure message below can show what went wrong.
    poster_errors: list[BaseException] = []

    def post_later() -> None:
        time.sleep(0.3)
        try:
            r = client.post(
                "/api/v1/chat",
                json={
                    "scope": "orchestrator", "scope_ref": None,
                    "body": "live-update-marker",
                },
            )
            target_rowid.append(r.json()["user_message"]["rowid"])
        except Exception as exc:  # reported via the assertion below
            poster_errors.append(exc)

    poster = threading.Thread(target=post_later, daemon=True)
    poster.start()
    with client.stream(
        "GET",
        "/chat/stream",
        params={
            "scope": "orchestrator", "after_rowid": baseline,
            "poll_interval": 0.1, "max_duration": 3.0,
        },
    ) as resp:
        assert resp.status_code == 200
        poster.join(timeout=2.0)
        assert target_rowid, (
            f"background poster failed to deliver: {poster_errors!r}"
        )
        text = _collect_until(resp, str(target_rowid[0]), timeout=4.0)
    assert "event: chat" in text
    assert "new_rowids" in text
    # Require the rowid as a whole number inside a `new_rowids` list rather
    # than as an arbitrary substring of the stream.
    rowid_re = re.escape(str(target_rowid[0]))
    assert re.search(rf'"new_rowids"\s*:\s*\[[^\]]*\b{rowid_re}\b', text), text