diff --git a/src/claire/agent/supervisor.py b/src/claire/agent/supervisor.py new file mode 100644 index 0000000..d912df2 --- /dev/null +++ b/src/claire/agent/supervisor.py @@ -0,0 +1,143 @@ +"""Worker-session supervisor — detect wedged/orphaned LOCAL sessions, recover. + +Reuses rclaude's `list_sessions` (JSONL mtime) + `list_tmux` (live panes) — +NEVER `capture-pane` (apricot's tmux heap-corrupts on it). A session is: + - **wedged**: has a live tmux pane (matched via `resumed_uuid`) but its JSONL + mtime is stale beyond `wedge_threshold_s` (alive process, not progressing). + - **orphaned**: an on-disk session with no live pane. + +Recovery ladder (all gated): kick (a one-line ping) → escalate (emit an +agent-status event so plum's fleet view flags it) → respawn (OFF unless +`supervisor_allow_respawn`). Only LOCAL sessions are touched (rclaude tags the +calling machine's own sessions/panes `host == "local"`). +""" + +from __future__ import annotations + +import asyncio +import logging +import time +from pathlib import Path + +logger = logging.getLogger(__name__) + +# session_uuid -> consecutive wedge detections (reset when it recovers). +_escalation_state: dict[str, int] = {} + +_LOCAL = "local" + + +def detect_wedged_and_orphaned(sessions, tmux_rows, *, wedge_threshold_s, now): + """Pure classification of LOCAL sessions. Returns (wedged, orphaned) lists + of SessionRow. `now` is epoch seconds. + + Join key is `TmuxRow.resumed_uuid == SessionRow.uuid`. If no tmux row + carries a `resumed_uuid` (older rclaude omits the column), correlation is + impossible, so nothing is classified wedged (we never act blind). + """ + live_uuids = {str(r.resumed_uuid) for r in tmux_rows if getattr(r, "resumed_uuid", None)} + wedged, orphaned = [], [] + for s in sessions: + if s.host != _LOCAL: + continue # supervise only this machine's own sessions + if str(s.uuid) in live_uuids: + if now - s.mtime_epoch > wedge_threshold_s: + wedged.append(s) + else: + orphaned.append(s) + return wedged, orphaned + + +async def supervisor_loop(*, config_path: Path | None, db_path: Path | None) -> None: + from ..config import load_or_init + from ..rclaude import Rclaude, RclaudeError + + cfg = load_or_init(config_path) + if not cfg.agent.supervisor_enable: + logger.info("agent supervisor disabled (agent.supervisor_enable=false)") + return + threshold = cfg.agent.wedge_threshold_s + allow_respawn = cfg.agent.supervisor_allow_respawn + poll = min(60, max(20, threshold // 3)) + rcl = Rclaude() + logger.info( + "agent supervisor loop enabled (poll %ds, wedge>%ds, respawn=%s)", + poll, threshold, allow_respawn, + ) + + while True: + try: + await asyncio.sleep(poll) + sessions = await asyncio.to_thread(rcl.list_sessions) + tmux_rows = await asyncio.to_thread(rcl.list_tmux) + name_by_uuid = { + str(r.resumed_uuid): r.session_name + for r in tmux_rows + if getattr(r, "resumed_uuid", None) + } + wedged, orphaned = detect_wedged_and_orphaned( + sessions, tmux_rows, wedge_threshold_s=threshold, now=time.time() + ) + if orphaned: + logger.info("supervisor: %d orphaned local session(s)", len(orphaned)) + + for s in wedged: + key = str(s.uuid) + n = _escalation_state.get(key, 0) + 1 + _escalation_state[key] = n + logger.warning( + "supervisor: session %s wedged (mtime stale, x%d)", key[:8], n + ) + name = name_by_uuid.get(key) + if n == 1 and name: + # Kick: a gentle ping in case it's waiting on input. + try: + await asyncio.to_thread( + rcl.send, + text="[supervisor] still working? reply one line.", + match=name, + yes=True, + ) + except RclaudeError as exc: + logger.warning("supervisor kick failed for %s: %s", key[:8], exc) + elif n == 2: + await asyncio.to_thread(_escalate, db_path, cfg.machine_id, s) + elif allow_respawn and n >= 3 and name: + logger.warning("supervisor: respawning wedged %s", key[:8]) + try: + await asyncio.to_thread(rcl.kill, match=name, yes=True) + except RclaudeError as exc: + logger.warning("supervisor respawn-kill failed: %s", exc) + + # Reset escalation counters for sessions that recovered. + wedged_keys = {str(s.uuid) for s in wedged} + for key in list(_escalation_state): + if key not in wedged_keys: + _escalation_state.pop(key, None) + except asyncio.CancelledError: + return + except Exception as exc: # noqa: BLE001 + logger.warning("supervisor loop tick raised: %s", exc) + + +def _escalate(db_path, machine_id: str, session) -> None: + """Emit an agent-status event so plum's fleet view flags the wedge. Reuses + the existing AgentStatusReported (source='triage') — it syncs to plum.""" + from ..db import migrate, open_db + from ..events import AgentStatusReported, append + from ..hlc import HLCGenerator + + conn = open_db(db_path) + try: + migrate(conn) + append( + conn, + HLCGenerator(machine_id), + AgentStatusReported( + session_uuid=session.uuid, + summary="[supervisor] wedged — no JSONL progress; kicked, awaiting reply", + source="triage", + ), + ) + finally: + conn.close() diff --git a/src/claire/agent/telemetry.py b/src/claire/agent/telemetry.py new file mode 100644 index 0000000..b9ca424 --- /dev/null +++ b/src/claire/agent/telemetry.py @@ -0,0 +1,99 @@ +"""Host telemetry loop — sample CPU/mem/load/disk, emit HostTelemetryReported. + +Events sync to plum and surface in `fleet_load` so dispatch sees real host +capacity, not just live-session counts. One snapshot row per host (last-write- +wins); sampled at `agent.telemetry_interval_s`. +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from pathlib import Path + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class HostMetrics: + cpu_percent: float + mem_used_bytes: int + mem_total_bytes: int + load_1: float + load_5: float + load_15: float + disk_used_bytes: int + disk_total_bytes: int + + +def sample_host_metrics() -> HostMetrics: + """Blocking psutil sample (cpu_percent needs a measurement window). Run in + a thread so the event loop isn't stalled.""" + import psutil + + vm = psutil.virtual_memory() + du = psutil.disk_usage("/") + try: + load_1, load_5, load_15 = psutil.getloadavg() + except (OSError, AttributeError): # not available on some platforms + load_1 = load_5 = load_15 = 0.0 + return HostMetrics( + cpu_percent=psutil.cpu_percent(interval=0.5), + mem_used_bytes=int(vm.used), + mem_total_bytes=int(vm.total), + load_1=float(load_1), + load_5=float(load_5), + load_15=float(load_15), + disk_used_bytes=int(du.used), + disk_total_bytes=int(du.total), + ) + + +async def telemetry_loop(*, config_path: Path | None, db_path: Path | None) -> None: + from ..config import load_or_init + + cfg = load_or_init(config_path) + interval = cfg.agent.telemetry_interval_s + host = cfg.this_host_label() + machine_id = cfg.machine_id + logger.info( + "agent telemetry loop enabled (every %ds, host=%s)", interval, host + ) + + def _sample_and_persist() -> HostMetrics: + from ..db import migrate, open_db + from ..events import HostTelemetryReported, append + from ..hlc import HLCGenerator + + m = sample_host_metrics() + conn = open_db(db_path) + try: + migrate(conn) + append( + conn, + HLCGenerator(machine_id), + HostTelemetryReported( + host=host, + cpu_percent=m.cpu_percent, + mem_used_bytes=m.mem_used_bytes, + mem_total_bytes=m.mem_total_bytes, + load_1=m.load_1, + load_5=m.load_5, + load_15=m.load_15, + disk_used_bytes=m.disk_used_bytes, + disk_total_bytes=m.disk_total_bytes, + ), + ) + finally: + conn.close() + return m + + while True: + try: + await asyncio.sleep(interval) + await asyncio.to_thread(_sample_and_persist) + except asyncio.CancelledError: + return + except Exception as exc: # noqa: BLE001 + logger.warning("telemetry loop tick raised: %s", exc)