190 lines
6.2 KiB
Python
190 lines
6.2 KiB
Python
"""Smoke tests for the orchestrator MCP server and turn registry.

We don't speak the full MCP protocol in unit tests — that's reserved for
end-to-end tests once a real Claude session connects. Here we just verify:

- The FastMCP server builds and registers every expected tool.
- The turn registry round-trips a body via deliver_reply ↔ wait_for_reply.
- `submit_chat_reply` (invoked through FastMCP's `call_tool`) honors the
  turn_id contract.
- `extract_turn_id` parses the `[turn:<id>]` prefix of a message.
- Tool calls log a `→ tool(args)` system message to the orchestrator chat.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
import time
|
|
|
|
import pytest
|
|
|
|
from claire.orchestrator import mcp_server, turns
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Turn registry
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_wait_for_reply_returns_delivered_body() -> None:
    """A body passed to deliver_reply reaches a thread blocked in wait_for_reply."""
    tid = "abc12345"
    turns.register(tid)
    captured: list[str | None] = []

    def waiter() -> None:
        captured.append(turns.wait_for_reply(tid, timeout_s=2.0))

    t = threading.Thread(target=waiter)
    t.start()
    # Brief pause so the waiter has most likely entered wait_for_reply before
    # the reply is delivered.
    time.sleep(0.05)
    assert turns.deliver_reply(tid, body="hello!")

    t.join(timeout=2.0)
    # join() returns silently when its timeout expires; make a hung waiter an
    # explicit failure instead of a confusing `[] == ["hello!"]` mismatch.
    assert not t.is_alive(), "waiter did not return after deliver_reply"
    assert captured == ["hello!"]
    assert turns.pending_count() == 0
|
|
|
|
|
|
def test_wait_for_reply_times_out() -> None:
    """With nothing delivered, wait_for_reply yields None and leaves no pending turn."""
    turn_id = "def67890"
    turns.register(turn_id)

    assert turns.wait_for_reply(turn_id, timeout_s=0.05) is None
    assert turns.pending_count() == 0
|
|
|
|
|
|
def test_deliver_unknown_turn_returns_false() -> None:
    """Delivering to a turn id that was never registered reports failure (falsy)."""
    delivered = turns.deliver_reply("no-such-turn", body="x")
    assert not delivered
|
|
|
|
|
|
def test_register_duplicate_turn_raises() -> None:
    """Registering an already-registered turn id raises RuntimeError."""
    turn_id = "dup-turn-12345"
    turns.register(turn_id)
    try:
        with pytest.raises(RuntimeError, match="already registered"):
            turns.register(turn_id)
    finally:
        # A zero-timeout wait consumes the registered slot so later tests
        # begin with a clean registry.
        turns.wait_for_reply(turn_id, timeout_s=0.0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Turn-id prefix extractor
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_extract_turn_id_from_prefix() -> None:
    """A leading ``[turn:<id>]`` tag yields ``<id>``."""
    tagged = "[turn:abc12345] hello"
    assert mcp_server.extract_turn_id(tagged) == "abc12345"
|
|
|
|
|
|
def test_extract_turn_id_returns_none_when_missing() -> None:
    """A message with no turn tag yields None."""
    untagged = "just a message"
    assert mcp_server.extract_turn_id(untagged) is None
|
|
|
|
|
|
def test_extract_turn_id_handles_leading_whitespace() -> None:
    """Whitespace before the turn tag does not prevent extraction."""
    padded = " [turn:abcd9876] body"
    expected = "abcd9876"
    assert mcp_server.extract_turn_id(padded) == expected
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# FastMCP server: tool registration and tool calls
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_server_registers_expected_tools() -> None:
    """build_server() must expose every tool listed below (extra tools are fine)."""
    tools = await mcp_server.build_server().list_tools()
    registered = {tool.name for tool in tools}
    required = {
        "create_project",
        "add_task",
        "list_tasks",
        "create_assignment",
        "broadcast",
        "pull",
        "status",
        "help",
        "list_recent_events",
        "search_chat_messages",
        "get_session",
        "summarize_project",
        "suggest_assignments",
        "send_to_session",
        "submit_chat_reply",
    }
    missing = required - registered
    assert not missing, f"missing tools: {missing}"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_submit_chat_reply_without_turn_id_errors_cleanly() -> None:
    """Omitting turn_id yields an ``ok: False`` payload whose error mentions turn_id."""
    server = mcp_server.build_server()
    raw = await server.call_tool("submit_chat_reply", {"body": "no turn id here"})
    # call_tool()'s return shape differs across SDK versions; _unwrap() digs
    # out the dict the tool returned.
    payload = _unwrap(raw)
    assert payload["ok"] is False
    assert "turn_id" in payload["error"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_call_posts_intent_log_to_orchestrator_chat(
    tmp_path, monkeypatch,
) -> None:
    """Every MCP tool should log a `→ tool(args)` system message before running.

    Uses XDG_*HOME isolation so the log writes go to a tmp DB. After calling
    ``create_project`` it checks the orchestrator chat for both the
    ``→ create_project(`` intent line and the "Project created" message. It
    also checks that the intent line is the earlier of the two.
    """
    monkeypatch.setenv("XDG_DATA_HOME", str(tmp_path / "data"))
    monkeypatch.setenv("XDG_CONFIG_HOME", str(tmp_path / "config"))
    # Reset the cached default server so it sees the new env.
    monkeypatch.setattr(mcp_server, "_DEFAULT_SERVER", None)

    server = mcp_server.build_server()
    await server.call_tool("create_project", {"name": "alpha", "goal": "ship"})

    # Inspect chat_messages for the tool-call log.
    from claire.db import migrate, open_db

    conn = open_db()
    try:
        # migrate() sits inside the try so the connection is closed even if
        # migration fails.
        migrate(conn)
        rows = conn.execute(
            "SELECT body FROM chat_messages "
            "WHERE scope='orchestrator' AND role='system' "
            "ORDER BY rowid"
        ).fetchall()
    finally:
        conn.close()
    bodies = [r["body"] for r in rows]

    intent_idx = next(
        (i for i, b in enumerate(bodies) if b.startswith("→ create_project(")),
        None,
    )
    effect_idx = next(
        (i for i, b in enumerate(bodies) if "Project created" in b),
        None,
    )
    assert intent_idx is not None, f"no tool-call log in {bodies!r}"
    # The actual side-effect ("Project created: alpha") still appears too.
    assert effect_idx is not None, f"no 'Project created' message in {bodies!r}"
    # The tool-call log comes before the side-effect fan-out (rows are in
    # rowid order).
    assert intent_idx < effect_idx, f"intent log not first: {bodies!r}"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_submit_chat_reply_with_turn_id_unblocks_waiter() -> None:
    """submit_chat_reply with a registered turn_id wakes the thread blocked in
    wait_for_reply and hands it the reply body."""
    server = mcp_server.build_server()
    tid = "live-turn-1"
    turns.register(tid)
    captured: list[str | None] = []

    def waiter() -> None:
        captured.append(turns.wait_for_reply(tid, timeout_s=2.0))

    t = threading.Thread(target=waiter)
    t.start()
    # Brief pause so the waiter has most likely entered wait_for_reply before
    # the tool call delivers the reply.
    time.sleep(0.05)

    result = await server.call_tool(
        "submit_chat_reply", {"body": "all done", "turn_id": tid},
    )
    payload = _unwrap(result)
    assert payload["ok"] is True

    t.join(timeout=2.0)
    # join() returns silently when its timeout expires; make a hung waiter an
    # explicit failure instead of a confusing `[] == ["all done"]` mismatch.
    assert not t.is_alive(), "waiter did not return after submit_chat_reply"
    assert captured == ["all done"]
    # A consumed turn must leave no pending slot behind for later tests.
    assert turns.pending_count() == 0
|
|
|
|
|
|
def _unwrap(result) -> dict:
    """Pull the structured payload out of a FastMCP call_tool() return.

    Two layouts have been seen across SDK versions:

    - a ``(content_blocks, structured_content)`` tuple, and
    - a bare list of content blocks whose first block's ``text`` is JSON.

    A plain dict is returned as-is. If a tuple's structured part is not a
    dict, the JSON carried by its content blocks is used instead.

    Raises:
        AssertionError: if no dict payload can be recovered from ``result``.
    """
    import json

    if isinstance(result, dict):
        return result
    blocks = result
    if isinstance(result, tuple) and len(result) == 2:
        blocks, structured = result
        if isinstance(structured, dict):
            return structured
    # Fallback: a sequence of content blocks; first block's text holds JSON.
    if isinstance(blocks, (list, tuple)) and blocks:
        text = getattr(blocks[0], "text", None)
        if text is not None:
            payload = json.loads(text)
            if isinstance(payload, dict):
                return payload
    raise AssertionError(f"unrecognized FastMCP call_tool result: {result!r}")
|