240 lines
8.9 KiB
Python
240 lines
8.9 KiB
Python
"""Tests for the rate/load-gated dispatch engine."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid as _uuid
|
|
from dataclasses import dataclass
|
|
from uuid import UUID
|
|
|
|
import pytest
|
|
|
|
from claire import events as ev
|
|
from claire.db import migrate, open_db
|
|
from claire.hlc import HLCGenerator
|
|
from claire import scheduler
|
|
from claire.web import service
|
|
|
|
|
|
def _setup() -> tuple:
    """Open a fresh ``:memory:`` database, migrate it, and pair it with an HLC.

    Returns a ``(conn, gen)`` tuple where ``gen`` is an ``HLCGenerator`` for
    the fixed machine id ``"test-machine"``.
    """
    db = open_db(":memory:")
    migrate(db)
    clock = HLCGenerator("test-machine")
    return db, clock
|
|
|
|
|
|
# --- host load / caps ------------------------------------------------------
|
|
|
|
|
|
def _alive_session(conn, gen, host: str, cwd: str) -> UUID:
    """Observe a brand-new session on *host*/*cwd* and mark it alive.

    Returns the freshly generated session UUID.
    """
    session_id = _uuid.uuid4()
    observed = ev.SessionObserved(session_uuid=session_id, host=host, cwd=cwd)
    ev.append(conn, gen, observed)
    # Liveness is forced with a direct UPDATE on the sessions table; the
    # uuid column is matched as text.
    conn.execute(
        "UPDATE sessions SET liveness = 'alive' WHERE uuid = ?",
        (str(session_id),),
    )
    return session_id
|
|
|
|
|
|
def test_host_load_counts_only_alive() -> None:
    """host_load tallies alive sessions only; merely-observed ones are ignored."""
    conn, gen = _setup()
    for path in ("/a", "/b"):
        _alive_session(conn, gen, "apricot", path)

    # Observed but never marked alive: this one must not be counted.
    ghost = _uuid.uuid4()
    ev.append(
        conn, gen, ev.SessionObserved(session_uuid=ghost, host="apricot", cwd="/c")
    )

    per_host = scheduler.host_load(conn)
    assert per_host.get("apricot") == 2
|
|
|
|
|
|
def test_host_has_capacity() -> None:
    """A host at its per-host limit is full; a higher limit or an idle host has room."""
    conn, gen = _setup()
    for slot in range(3):
        _alive_session(conn, gen, "apricot", f"/p{slot}")

    # (host, per_host_max, expected). apricot has 3 live sessions: exactly at
    # a cap of 3, below a cap of 4. plum has none.
    cases = [
        ("apricot", 3, False),
        ("apricot", 4, True),
        ("plum", 3, True),
    ]
    for host, cap, expected in cases:
        assert scheduler.host_has_capacity(conn, host, per_host_max=cap) is expected
|
|
|
|
|
|
# --- fake rclaude for dispatch ---------------------------------------------
|
|
|
|
|
|
@dataclass
|
|
class _FakeSessionRow:
|
|
host: str
|
|
uuid: UUID
|
|
cwd: str
|
|
mtime_epoch: int
|
|
snippet: str = ""
|
|
|
|
|
|
class _FakeRclaude:
    """In-memory rclaude double used by the dispatch tests.

    ``spawn()`` logs its arguments and registers a fresh session row so a
    later ``list_sessions()`` can discover it; ``send()`` only logs.
    """

    def __init__(self, *, spawn_ok: bool = True) -> None:
        # Rows that list_sessions() reports, in spawn order.
        self._rows: list[_FakeSessionRow] = []
        # One dict per successful spawn() call.
        self.spawn_calls: list[dict] = []
        # One dict per send() call.
        self.send_calls: list[dict] = []
        # When False, spawn() raises instead of creating a session.
        self._spawn_ok = spawn_ok

    def list_sessions(self) -> list[_FakeSessionRow]:
        """Return a shallow copy of the known session rows."""
        return self._rows.copy()

    def list_tmux(self) -> list:
        """Report no tmux panes."""
        return list()

    def spawn(
        self,
        *,
        host: str,
        cwd: str,
        mcp_config: str | None = None,
        name: str | None = None,
    ) -> str:
        """Pretend to start a session on *host* in *cwd*.

        Returns ``claude-tester-<n>``, where ``n`` is the number of successful
        spawns so far (1 for the first).

        Raises:
            RclaudeError: when the fake was built with ``spawn_ok=False``.
        """
        from claire.rclaude import RclaudeError

        if not self._spawn_ok:
            raise RclaudeError("spawn failed")

        call = {"host": host, "cwd": cwd, "mcp_config": mcp_config, "name": name}
        self.spawn_calls.append(call)
        # Register the new session so a subsequent list_sessions() sees it.
        fresh = _FakeSessionRow(host=host, uuid=_uuid.uuid4(), cwd=cwd, mtime_epoch=999)
        self._rows.append(fresh)
        ordinal = len(self.spawn_calls)
        return f"claude-tester-{ordinal}"

    # dispatch_task kicks the new session via `.send()` so Claude flushes its
    # JSONL; here the call is only recorded.
    def send(self, *, text: str, match: str, yes: bool = False, dry_run: bool = False):  # noqa: ARG002
        """Log *text*, *match* and *yes*; ``dry_run`` is accepted but not logged."""
        self.send_calls.append(dict(text=text, match=match, yes=yes))
|
|
|
|
|
|
# dispatch_task stages a clare `.mcp.json` via a real ssh/subprocess stager.
|
|
# Tests must never touch the network — inject a stub instead.
|
|
def _no_mcp_stager(host: str) -> str | None: # noqa: ARG001
|
|
return None
|
|
|
|
|
|
def _project_with_task(conn, gen, *, priority: int = 2):
    """Create a project named "p" holding one task titled "t".

    *priority* is passed through to the task. Returns the task's UUID.
    """
    project_id = _uuid.uuid4()
    new_task_id = _uuid.uuid4()
    project_event = ev.ProjectCreated(project_id=project_id, name="p")
    ev.append(conn, gen, project_event)
    task_event = ev.TaskAdded(
        task_id=new_task_id,
        project_id=project_id,
        title="t",
        priority=priority,
    )
    ev.append(conn, gen, task_event)
    return new_task_id
|
|
|
|
|
|
def test_dispatch_success() -> None:
    """A dispatch within every limit spawns once, wires config/name, and records state."""
    conn, gen = _setup()
    task_id = _project_with_task(conn, gen)
    rcl = _FakeRclaude()
    result = service.dispatch_task(
        conn, gen, task_id=task_id, host="plum", cwd="/work",
        rclaude=rcl, discover_timeout_s=2,
        mcp_stager=lambda _host: "/home/u/.local/share/claire/dispatch-mcp.json",
    )
    assert result.dispatched is True
    assert result.reason == "ok"
    assert result.session_uuid is not None
    assert result.assignment_id is not None
    assert len(rcl.spawn_calls) == 1
    # The staged mcp_config path is wired through to the spawn.
    assert rcl.spawn_calls[0]["mcp_config"] == (
        "/home/u/.local/share/claire/dispatch-mcp.json"
    )
    # The session carries a display name slugified from the task title ("t").
    assert rcl.spawn_calls[0]["name"] == "t"
    # The kick + /remote-control sends both target the spawned tmux session
    # name (unique), NOT the bare cwd-slug (which a sibling session shares).
    tmux_name = "claude-tester-1"  # _FakeRclaude.spawn's first return value
    assert all(c["match"] == tmux_name for c in rcl.send_calls)
    # A /remote-control registration send occurred with the task-slug name.
    rc_sends = [c for c in rcl.send_calls if c["text"].startswith("/remote-control ")]
    assert len(rc_sends) == 1
    assert rc_sends[0]["text"] == "/remote-control t"
    # Assignment row exists.
    rows = conn.execute(
        "SELECT COUNT(*) FROM assignments WHERE task_id = ?", (str(task_id),)
    ).fetchone()
    assert rows[0] == 1
    # The dispatched session is addressable: its sessions row carries the
    # spawned tmux_name (a fresh pane has no `--resume`, so the pull loop
    # would never map it — dispatch must record it directly).
    sess = conn.execute(
        "SELECT tmux_name FROM sessions WHERE uuid = ?",
        # Bind as text, like every other uuid lookup in this file. Stock
        # sqlite3 cannot bind a uuid.UUID object unless an adapter is
        # registered, and str() is a no-op if this is already a string.
        (str(result.session_uuid),),
    ).fetchone()
    assert sess is not None
    assert sess["tmux_name"] == tmux_name
|
|
|
|
|
|
def test_dispatch_refused_host_at_cap() -> None:
    """Dispatch refuses, without spawning, when the target host is already full."""
    conn, gen = _setup()
    task_id = _project_with_task(conn, gen)
    # Three live sessions put plum at the default per-host cap of 3.
    for slot in range(3):
        _alive_session(conn, gen, "plum", f"/filler{slot}")

    fake = _FakeRclaude()
    outcome = service.dispatch_task(
        conn,
        gen,
        task_id=task_id,
        host="plum",
        cwd="/work",
        rclaude=fake,
        discover_timeout_s=2,
        mcp_stager=_no_mcp_stager,
    )

    assert outcome.dispatched is False
    assert "capacity" in outcome.reason
    assert not fake.spawn_calls  # spawn was never attempted
|
|
|
|
|
|
def test_dispatch_refused_over_budget(monkeypatch, tmp_path) -> None:
    """Once recorded usage passes the daily token cap, dispatch refuses to spawn."""
    # Isolated config file with a tiny daily cap of 100 tokens.
    config_root = tmp_path / "config"
    cfg_path = config_root / "claire" / "claire.toml"
    monkeypatch.setenv("XDG_CONFIG_HOME", str(config_root))
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    toml_text = (
        'machine_id = "m"\n\n[web]\nhost = "127.0.0.1"\nport = 8765\n'
        "\n[budget]\ndaily_token_cap = 100\nlow_priority_floor = 0.8\n"
    )
    cfg_path.write_text(toml_text, encoding="utf-8")

    conn, gen = _setup()
    # NOTE(review): a P3 task is also subject to the low-priority floor (0.8),
    # so the refusal may come from the floor rather than the hard cap. Confirm
    # the reason text actually exercises the cap check.
    task_id = _project_with_task(conn, gen, priority=3)
    # Record 120 input + 10 output tokens, past the 100-token cap.
    service.record_usage(
        conn, gen, source="nl", model="haiku", input_tokens=120, output_tokens=10
    )

    fake = _FakeRclaude()
    outcome = service.dispatch_task(
        conn,
        gen,
        task_id=task_id,
        host="plum",
        cwd="/work",
        rclaude=fake,
        discover_timeout_s=2,
        mcp_stager=_no_mcp_stager,
    )

    assert outcome.dispatched is False
    assert "cap" in outcome.reason.lower()
    assert not fake.spawn_calls
|
|
|
|
|
|
def test_dispatch_low_priority_floor(monkeypatch, tmp_path) -> None:
    """Past the low-priority floor, P3 work is refused while P0 work still dispatches."""
    monkeypatch.setenv("XDG_CONFIG_HOME", str(tmp_path / "config"))
    cfg_path = tmp_path / "config" / "claire" / "claire.toml"
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    cfg_path.write_text(
        'machine_id = "m"\n\n[web]\nhost = "127.0.0.1"\nport = 8765\n'
        "\n[budget]\ndaily_token_cap = 1000\nlow_priority_floor = 0.5\n",
        encoding="utf-8",
    )
    conn, gen = _setup()
    # 600/1000 used = 60% — past the 50% floor.
    service.record_usage(conn, gen, source="nl", model="haiku",
                         input_tokens=600, output_tokens=0)
    rcl = _FakeRclaude()

    # A P3 task is refused past the floor, and nothing is spawned for it...
    low = _project_with_task(conn, gen, priority=3)
    r_low = service.dispatch_task(conn, gen, task_id=low, host="plum",
                                  cwd="/w", rclaude=rcl, discover_timeout_s=2,
                                  mcp_stager=_no_mcp_stager)
    assert r_low.dispatched is False
    assert "low-priority" in r_low.reason
    assert rcl.spawn_calls == []  # consistent with the other refusal tests

    # ...but a P0 task in the same project still goes through. The P3 task is
    # the only task so far, so LIMIT 1 returns its project_id (stored as text).
    project_row = conn.execute("SELECT project_id FROM tasks LIMIT 1").fetchone()
    project_id = UUID(project_row[0])
    p0 = _uuid.uuid4()
    ev.append(conn, gen, ev.TaskAdded(
        task_id=p0, project_id=project_id, title="urgent", priority=0,
    ))
    r_p0 = service.dispatch_task(conn, gen, task_id=p0, host="plum",
                                 cwd="/w", rclaude=rcl, discover_timeout_s=2,
                                 mcp_stager=_no_mcp_stager)
    assert r_p0.dispatched is True
    # Degraded path: stager returned None → spawn carried no mcp_config.
    assert rcl.spawn_calls[-1]["mcp_config"] is None
|