claire/tests/test_dispatch.py

192 lines
6.6 KiB
Python
Raw Normal View History

"""Tests for the rate/load-gated dispatch engine."""
from __future__ import annotations
import uuid as _uuid
from dataclasses import dataclass
from uuid import UUID
import pytest
from claire import events as ev
from claire.db import migrate, open_db
from claire.hlc import HLCGenerator
from claire import scheduler
from claire.web import service
def _setup() -> tuple:
    """Create a fresh, migrated in-memory database plus an HLC generator.

    Returns:
        ``(conn, gen)``: the migrated connection and an ``HLCGenerator``
        for the machine id ``"test-machine"``.
    """
    db = open_db(":memory:")
    migrate(db)
    clock = HLCGenerator("test-machine")
    return db, clock
# --- host load / caps ------------------------------------------------------
def _alive_session(conn, gen, host: str, cwd: str) -> UUID:
    """Observe a new session on *host* in *cwd* and mark it alive.

    Returns:
        The freshly generated session UUID.
    """
    session_id = _uuid.uuid4()
    observed = ev.SessionObserved(session_uuid=session_id, host=host, cwd=cwd)
    ev.append(conn, gen, observed)
    # Observing a session alone does not make it count toward host load
    # (see test_host_load_counts_only_alive), so flip liveness directly.
    conn.execute(
        "UPDATE sessions SET liveness = 'alive' WHERE uuid = ?",
        (str(session_id),),
    )
    return session_id
def test_host_load_counts_only_alive() -> None:
    """host_load() counts alive sessions and ignores merely-observed ones."""
    conn, gen = _setup()
    for cwd in ("/a", "/b"):
        _alive_session(conn, gen, "apricot", cwd)
    # Observed but never marked alive: must not contribute to the load.
    ev.append(
        conn,
        gen,
        ev.SessionObserved(session_uuid=_uuid.uuid4(), host="apricot", cwd="/c"),
    )
    assert scheduler.host_load(conn).get("apricot") == 2
def test_host_has_capacity() -> None:
    """A host is full once its alive-session count reaches per_host_max."""
    conn, gen = _setup()
    for n in range(3):
        _alive_session(conn, gen, "apricot", f"/p{n}")
    expectations = (
        ("apricot", 3, False),  # 3 alive vs. cap 3: full
        ("apricot", 4, True),   # 3 alive vs. cap 4: one slot left
        ("plum", 3, True),      # no sessions at all on plum
    )
    for host, cap, expected in expectations:
        assert scheduler.host_has_capacity(conn, host, per_host_max=cap) is expected
# --- fake rclaude for dispatch ---------------------------------------------
@dataclass
class _FakeSessionRow:
    """Stand-in for a session row, as returned by ``_FakeRclaude.list_sessions``."""

    host: str  # host the session was spawned on
    uuid: UUID  # session identifier
    cwd: str  # working directory passed to spawn()
    mtime_epoch: int  # presumably a modification time in epoch seconds; TODO confirm
    snippet: str = ""  # preview text; empty unless given
class _FakeRclaude:
    """In-memory double for the ``rclaude`` client passed to ``dispatch_task``.

    ``spawn()`` records each call and registers a new session row for the
    requested host/cwd, so a later ``list_sessions()`` can discover it.
    Built with ``spawn_ok=False``, ``spawn()`` raises ``RclaudeError`` instead.
    """

    def __init__(self, *, spawn_ok: bool = True) -> None:
        self._spawn_ok = spawn_ok
        self.spawn_calls: list[dict] = []
        self._rows: list[_FakeSessionRow] = []

    def list_sessions(self) -> list[_FakeSessionRow]:
        """Return a copy of every session row spawned so far."""
        return self._rows[:]

    def list_tmux(self) -> list:
        """Report no tmux sessions."""
        return []

    def spawn(self, *, host: str, cwd: str, mcp_config: str | None = None) -> str:
        """Pretend to start a session on *host* in *cwd* and return a synthetic name.

        ``mcp_config`` is accepted for interface compatibility and ignored.

        Raises:
            RclaudeError: when the fake was constructed with ``spawn_ok=False``.
        """
        from claire.rclaude import RclaudeError

        if not self._spawn_ok:
            raise RclaudeError("spawn failed")
        self.spawn_calls.append({"host": host, "cwd": cwd})
        new_row = _FakeSessionRow(
            host=host,
            uuid=_uuid.uuid4(),
            cwd=cwd,
            mtime_epoch=999,
        )
        self._rows.append(new_row)
        return f"claude-tester-{len(self.spawn_calls)}"
def _project_with_task(conn, gen, *, priority: int = 2):
    """Create a project named "p" holding one task titled "t".

    Args:
        priority: priority recorded on the task (default 2).

    Returns:
        The new task's UUID.
    """
    project = _uuid.uuid4()
    task = _uuid.uuid4()
    ev.append(conn, gen, ev.ProjectCreated(project_id=project, name="p"))
    added = ev.TaskAdded(task_id=task, project_id=project, title="t", priority=priority)
    ev.append(conn, gen, added)
    return task
def test_dispatch_success(monkeypatch, tmp_path) -> None:
    """A task dispatched to an idle host is spawned, discovered and assigned."""
    # Point config lookup at an empty directory (as the budget tests do with
    # their own files) so a claire.toml in the developer's real config dir
    # cannot change caps or budget for this test.
    monkeypatch.setenv("XDG_CONFIG_HOME", str(tmp_path / "config"))
    conn, gen = _setup()
    task_id = _project_with_task(conn, gen)
    rcl = _FakeRclaude()
    result = service.dispatch_task(
        conn, gen, task_id=task_id, host="plum", cwd="/work",
        rclaude=rcl, discover_timeout_s=2,
    )
    assert result.dispatched is True
    assert result.reason == "ok"
    assert result.session_uuid is not None
    assert result.assignment_id is not None
    assert len(rcl.spawn_calls) == 1
    # Exactly one assignment row was recorded for the task.
    count = conn.execute(
        "SELECT COUNT(*) FROM assignments WHERE task_id = ?", (str(task_id),)
    ).fetchone()[0]
    assert count == 1
def test_dispatch_refused_host_at_cap(monkeypatch, tmp_path) -> None:
    """Dispatch to a host already at its session cap is refused without spawning."""
    # This test relies on the *default* per-host cap of 3, so point config
    # lookup at an empty directory; otherwise a claire.toml in the developer's
    # real config dir could raise or lower the cap and flip the outcome.
    monkeypatch.setenv("XDG_CONFIG_HOME", str(tmp_path / "config"))
    conn, gen = _setup()
    task_id = _project_with_task(conn, gen)
    # Fill plum to the default cap of 3.
    for i in range(3):
        _alive_session(conn, gen, "plum", f"/filler{i}")
    rcl = _FakeRclaude()
    result = service.dispatch_task(
        conn, gen, task_id=task_id, host="plum", cwd="/work",
        rclaude=rcl, discover_timeout_s=2,
    )
    assert result.dispatched is False
    assert "capacity" in result.reason
    assert rcl.spawn_calls == []  # never spawned
def test_dispatch_refused_over_budget(monkeypatch, tmp_path) -> None:
    """Once recorded usage exceeds daily_token_cap, dispatch is refused unspawned."""
    config_home = tmp_path / "config"
    # Isolated config with a tiny daily cap (100 tokens).
    monkeypatch.setenv("XDG_CONFIG_HOME", str(config_home))
    toml_file = config_home / "claire" / "claire.toml"
    toml_file.parent.mkdir(parents=True, exist_ok=True)
    toml_file.write_text(
        'machine_id = "m"\n\n[web]\nhost = "127.0.0.1"\nport = 8765\n'
        "\n[budget]\ndaily_token_cap = 100\nlow_priority_floor = 0.8\n",
        encoding="utf-8",
    )
    conn, gen = _setup()
    # NOTE(review): priority=3 also trips the 0.8 low-priority floor, so the
    # refusal may come from either gate; a P0 task would isolate the hard cap
    # if the cap applies to all priorities — confirm against service.
    task_id = _project_with_task(conn, gen, priority=3)
    # 120 input tokens alone already exceed the 100-token cap.
    service.record_usage(
        conn, gen, source="nl", model="haiku",
        input_tokens=120, output_tokens=10,
    )
    fake = _FakeRclaude()
    outcome = service.dispatch_task(
        conn, gen, task_id=task_id, host="plum", cwd="/work",
        rclaude=fake, discover_timeout_s=2,
    )
    assert outcome.dispatched is False
    assert "cap" in outcome.reason.lower()
    assert fake.spawn_calls == []
def test_dispatch_low_priority_floor(monkeypatch, tmp_path) -> None:
    """Past the low-priority floor, P3 work is refused but P0 work still dispatches."""
    monkeypatch.setenv("XDG_CONFIG_HOME", str(tmp_path / "config"))
    cfg_path = tmp_path / "config" / "claire" / "claire.toml"
    cfg_path.parent.mkdir(parents=True, exist_ok=True)
    cfg_path.write_text(
        'machine_id = "m"\n\n[web]\nhost = "127.0.0.1"\nport = 8765\n'
        "\n[budget]\ndaily_token_cap = 1000\nlow_priority_floor = 0.5\n",
        encoding="utf-8",
    )
    conn, gen = _setup()
    # 600/1000 used = 60% — past the 50% floor.
    service.record_usage(conn, gen, source="nl", model="haiku",
                         input_tokens=600, output_tokens=0)
    rcl = _FakeRclaude()

    # A P3 task is refused past the floor...
    low = _project_with_task(conn, gen, priority=3)
    r_low = service.dispatch_task(conn, gen, task_id=low, host="plum",
                                  cwd="/w", rclaude=rcl, discover_timeout_s=2)
    assert r_low.dispatched is False
    assert "low-priority" in r_low.reason
    assert rcl.spawn_calls == []  # refused before spawning

    # ...but a P0 task in the same project still goes through. Only one task
    # exists at this point, so LIMIT 1 selects the P3 task's project.
    project_id = UUID(
        conn.execute("SELECT project_id FROM tasks LIMIT 1").fetchone()[0]
    )
    p0 = _uuid.uuid4()
    ev.append(conn, gen, ev.TaskAdded(
        task_id=p0, project_id=project_id, title="urgent", priority=0,
    ))
    r_p0 = service.dispatch_task(conn, gen, task_id=p0, host="plum",
                                 cwd="/w", rclaude=rcl, discover_timeout_s=2)
    assert r_p0.dispatched is True
    assert len(rcl.spawn_calls) == 1