139 lines
4.2 KiB
Python
139 lines
4.2 KiB
Python
"""Tests for the chat SSE stream endpoint."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
@pytest.fixture
def client(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> TestClient:
    """Return a ``TestClient`` for a freshly created app.

    ``XDG_DATA_HOME`` and ``XDG_CONFIG_HOME`` are pointed at sub-directories
    of ``tmp_path`` before the application factory is imported and called.
    """
    xdg_dirs = {"XDG_DATA_HOME": "data", "XDG_CONFIG_HOME": "config"}
    for env_name, subdir in xdg_dirs.items():
        monkeypatch.setenv(env_name, str(tmp_path / subdir))

    # Imported inside the fixture, i.e. only after the env vars are in place.
    from clare.web.app import create_app

    app = create_app()
    return TestClient(app)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Content-type + scope / scope_ref validation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_stream_content_type_is_event_stream(client: TestClient) -> None:
    """The stream endpoint answers 200 with a ``text/event-stream`` content type."""
    query = {"scope": "orchestrator", "after_rowid": 0, "poll_interval": 0.05}
    with client.stream("GET", "/chat/stream", params=query) as response:
        assert response.status_code == 200
        content_type = response.headers["content-type"]
        assert content_type.startswith("text/event-stream")
        # Only the headers are inspected; the body is never read before closing.
|
|
|
|
|
|
def test_stream_bad_scope_returns_400(client: TestClient) -> None:
    """A ``scope`` value of ``"garbage"`` is rejected with HTTP 400."""
    query = {"scope": "garbage", "after_rowid": 0}
    response = client.get("/chat/stream", params=query)
    assert response.status_code == 400
|
|
|
|
|
|
def test_stream_invalid_scope_ref_returns_400(client: TestClient) -> None:
    """Passing ``scope_ref`` together with the orchestrator scope yields HTTP 400."""
    # The orchestrator scope takes no scope_ref, so this combination is invalid input.
    query = {
        "scope": "orchestrator",
        "scope_ref": "should-not-be-here",
        "after_rowid": 0,
    }
    response = client.get("/chat/stream", params=query)
    assert response.status_code == 400
|
|
|
|
|
|
def test_stream_unknown_project_returns_404(client: TestClient) -> None:
    """A project scope whose ``scope_ref`` is ``"nope"`` yields HTTP 404."""
    query = {"scope": "project", "scope_ref": "nope", "after_rowid": 0}
    response = client.get("/chat/stream", params=query)
    assert response.status_code == 404
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Streaming behaviour
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _collect_until(resp, marker: str, timeout: float = 3.0) -> str:
    """Read raw text from an SSE response until `marker` shows up or timeout.

    Args:
        resp: Streaming response object exposing ``iter_text()``.
        marker: Substring to wait for in the accumulated text.
        timeout: Time budget in seconds, measured from the start of the call.

    Returns:
        Everything read so far: up to and including the chunk that completed
        ``marker``, or whatever had arrived when the deadline passed or the
        stream ended.

    Note:
        The deadline is evaluated only after a chunk has been received, so the
        timeout cannot interrupt a stream that yields nothing.
    """
    # Monotonic clock: a wall-clock adjustment during the test must neither
    # stretch nor cut short the timeout.
    deadline = time.monotonic() + timeout
    buf = ""
    for chunk in resp.iter_text():
        buf += chunk
        # Search the whole buffer rather than the chunk alone, because the
        # marker may straddle a chunk boundary.
        if marker in buf:
            return buf
        if time.monotonic() > deadline:
            break
    return buf
|
|
|
|
|
|
def test_stream_yields_existing_backlog(client: TestClient) -> None:
    """Initial flush: a message stored before connecting is delivered on the stream."""
    seed_payload = {
        "scope": "orchestrator",
        "scope_ref": None,
        "body": "hello-backlog",
    }
    client.post("/api/v1/chat", json=seed_payload)

    # after_rowid=0 asks for everything, so the seeded message is in range.
    query = {"scope": "orchestrator", "after_rowid": 0, "poll_interval": 0.05}
    with client.stream("GET", "/chat/stream", params=query) as response:
        assert response.status_code == 200
        received = _collect_until(response, "hello-backlog", timeout=2.0)
        assert "event: chat" in received
        assert "hello-backlog" in received
|
|
|
|
|
|
def test_stream_pushes_new_message(client: TestClient) -> None:
    """A message posted AFTER the stream is requested is delivered on it.

    A background thread posts the message roughly 0.3s after it is started;
    the main thread reads the stream for up to 4s waiting for the marker.
    """
    # Snapshot baseline rowid.
    r0 = client.post(
        "/api/v1/chat",
        json={"scope": "orchestrator", "scope_ref": None, "body": "baseline"},
    )
    baseline = r0.json()["user_message"]["rowid"]

    # An exception raised inside a worker thread does not fail the test by
    # itself, so the poster records it here for the main thread to check.
    post_errors: list[BaseException] = []

    # Background poster: fires a message a moment after we start streaming.
    def post_later() -> None:
        try:
            time.sleep(0.3)
            posted = client.post(
                "/api/v1/chat",
                json={
                    "scope": "orchestrator",
                    "scope_ref": None,
                    "body": "live-update-marker",
                },
            )
            # Turn an HTTP error status into an exception so it is recorded too.
            posted.raise_for_status()
        except Exception as exc:  # recorded and asserted on below, never dropped
            post_errors.append(exc)

    poster = threading.Thread(target=post_later, daemon=True)
    poster.start()

    with client.stream(
        "GET",
        "/chat/stream",
        params={
            "scope": "orchestrator",
            "after_rowid": baseline,
            "poll_interval": 0.1,
        },
    ) as resp:
        assert resp.status_code == 200
        text = _collect_until(resp, "live-update-marker", timeout=4.0)

    poster.join(timeout=1.0)
    # Checked first so a failed POST is reported as the root cause instead of
    # as a missing marker.
    assert not post_errors, f"background POST failed: {post_errors[0]!r}"
    assert "live-update-marker" in text
    assert "event: chat" in text
|