225 lines
7.3 KiB
Python
225 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
import tomllib
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from claire.config import LimitsConfig, load_or_init
|
|
|
|
|
|
def test_load_or_init_creates_file_with_machine_id(tmp_path: Path) -> None:
    """First load creates the config file and assigns a stable uuid4 machine_id."""
    import uuid

    cfg_path = tmp_path / "claire.toml"
    cfg = load_or_init(cfg_path)
    assert cfg_path.exists()
    # Canonical hyphenated uuid4 string. A bare length check would accept any
    # 36-character value, so also require that it parses as a version-4 UUID.
    assert len(cfg.machine_id) == 36
    assert uuid.UUID(cfg.machine_id).version == 4
    # Reload returns the same machine_id (stable across runs).
    cfg2 = load_or_init(cfg_path)
    assert cfg2.machine_id == cfg.machine_id
|
|
|
|
|
|
def test_load_or_init_defaults_web_host_port(tmp_path: Path) -> None:
    """A freshly initialised config serves the web UI on 127.0.0.1:8765."""
    web = load_or_init(tmp_path / "claire.toml").web
    assert web.host == "127.0.0.1"
    assert web.port == 8765
|
|
|
|
|
|
def test_load_or_init_migrates_missing_sync_secret(tmp_path: Path) -> None:
    """A config lacking sync_secret gets one generated and written to disk.

    The secret must be persisted to the file and must stay the same on
    subsequent loads (no re-rolling).
    """
    legacy_toml = (
        'machine_id = "abc-123"\n'
        "\n"
        "[web]\n"
        'host = "127.0.0.1"\n'
        "port = 8765\n"
    )
    cfg_path = tmp_path / "claire.toml"
    cfg_path.write_text(legacy_toml, encoding="utf-8")

    migrated = load_or_init(cfg_path)
    secret = migrated.sync_secret
    assert secret is not None
    assert len(secret) > 0

    # The backfilled secret is persisted to the file.
    persisted = tomllib.loads(cfg_path.read_text(encoding="utf-8"))
    assert persisted["sync_secret"] == secret

    # Loading again keeps the same secret rather than minting a new one.
    assert load_or_init(cfg_path).sync_secret == secret
|
|
|
|
|
|
def test_load_or_init_migration_preserves_other_fields(tmp_path: Path) -> None:
    """Backfilling sync_secret must not disturb any other configured field.

    The fields are checked on the config returned by the migrating load AND
    on a fresh reload of the rewritten file. The migrating load's return value
    alone cannot show that the rewrite preserved web/peers/groups on disk.
    """
    cfg_path = tmp_path / "claire.toml"
    cfg_path.write_text(
        'machine_id = "preserved-id"\n'
        "\n"
        "[web]\n"
        'host = "0.0.0.0"\n'
        "port = 9999\n"
        "\n"
        "[[peers]]\n"
        'url = "http://peer-a.local"\n'
        'secret = "peer-a-secret"\n'
        "\n"
        "[[peers]]\n"
        'url = "http://peer-b.local"\n'
        "\n"
        "[[groups]]\n"
        'name = "docs"\n'
        'patterns = ["*.md", "*.txt"]\n',
        encoding="utf-8",
    )

    migrated = load_or_init(cfg_path)  # backfills sync_secret and rewrites
    reloaded = load_or_init(cfg_path)  # reads back the rewritten file

    for cfg in (migrated, reloaded):
        assert cfg.machine_id == "preserved-id"
        assert cfg.web.host == "0.0.0.0"
        assert cfg.web.port == 9999
        assert len(cfg.peers) == 2
        assert cfg.peers[0].url == "http://peer-a.local"
        assert cfg.peers[0].secret == "peer-a-secret"
        assert cfg.peers[1].url == "http://peer-b.local"
        assert cfg.peers[1].secret is None
        assert len(cfg.groups) == 1
        assert cfg.groups[0].name == "docs"
        assert cfg.groups[0].patterns == ["*.md", "*.txt"]
        assert cfg.sync_secret is not None
    # The secret written during migration is the one read back.
    assert reloaded.sync_secret == migrated.sync_secret
|
|
|
|
|
|
# --- per-host session caps -------------------------------------------------
|
|
|
|
|
|
def test_limits_cap_for_resolves_override_else_default() -> None:
    """cap_for() prefers a named per-host override, else falls back to per_host_max."""
    limits = LimitsConfig(per_host_max=3, per_host={"apricot": 8})
    expectations = (
        ("apricot", 8),  # named override wins
        ("plum", 3),  # absent host → default
    )
    for host, cap in expectations:
        assert limits.cap_for(host) == cap
|
|
|
|
|
|
def test_limits_defaults_to_empty_per_host() -> None:
    """A default LimitsConfig has no per-host overrides and uses the default cap."""
    default_limits = LimitsConfig()
    assert default_limits.per_host == {}
    # 3 is the per_host_max default.
    assert default_limits.cap_for("any-host") == 3
|
|
|
|
|
|
def test_limits_per_host_rejects_out_of_range() -> None:
    """Per-host caps that are too small (0) or too large (999) fail validation."""
    for bad_cap in (0, 999):
        with pytest.raises(ValidationError):
            LimitsConfig(per_host={"apricot": bad_cap})
|
|
|
|
|
|
def test_load_or_init_reads_per_host_caps(tmp_path: Path) -> None:
    """The [limits] table's per_host_max and per_host map are read from disk."""
    cfg_path = tmp_path / "claire.toml"
    cfg_path.write_text(
        'machine_id = "m"\n'
        'sync_secret = "s"\n'
        "\n[web]\n"
        'host = "127.0.0.1"\n'
        "port = 8765\n"
        "\n[limits]\n"
        "per_host_max = 4\n"
        'per_host = { "apricot" = 8, "local" = 6 }\n',
        encoding="utf-8",
    )

    limits = load_or_init(cfg_path).limits
    assert limits.per_host_max == 4
    assert limits.cap_for("apricot") == 8
    assert limits.cap_for("local") == 6
    assert limits.cap_for("plum") == 4  # unnamed → default
|
|
|
|
|
|
# --- host detection / known_hosts ------------------------------------------
|
|
|
|
|
|
def test_this_host_label_explicit_overrides_hostname(tmp_path: Path) -> None:
    """An explicit this_host setting is used verbatim as the host label."""
    toml_text = "".join(
        [
            'machine_id = "m"\n',
            'sync_secret = "s"\n',
            'this_host = "plum"\n',
            '\n[web]\nhost = "127.0.0.1"\nport = 8765\n',
        ]
    )
    cfg_path = tmp_path / "claire.toml"
    cfg_path.write_text(toml_text, encoding="utf-8")

    cfg = load_or_init(cfg_path)
    assert cfg.this_host == "plum"
    assert cfg.this_host_label() == "plum"
|
|
|
|
|
|
def test_this_host_label_defaults_to_short_os_hostname(tmp_path: Path) -> None:
    """Without this_host, the label is the lower-cased first dotted part of the OS hostname."""
    import socket

    cfg = load_or_init(tmp_path / "claire.toml")
    short_name, _sep, _domain = socket.gethostname().partition(".")
    assert cfg.this_host is None
    assert cfg.this_host_label() == short_name.lower()
|
|
|
|
|
|
def test_resolve_host_label_rewrites_local_to_this_host(tmp_path: Path) -> None:
    """The placeholder label "local" resolves to the configured this_host."""
    cfg_path = tmp_path / "claire.toml"
    toml_text = (
        'machine_id = "m"\n'
        'sync_secret = "s"\n'
        'this_host = "plum"\n'
        '\n[web]\nhost = "127.0.0.1"\nport = 8765\n'
    )
    cfg_path.write_text(toml_text, encoding="utf-8")

    resolve = load_or_init(cfg_path).resolve_host_label
    assert resolve("local") == "plum"
    # Already-canonical labels pass through.
    assert resolve("apricot") == "apricot"
|
|
|
|
|
|
def test_resolve_host_label_uses_aliases(tmp_path: Path) -> None:
    """known_hosts aliases map to their canonical name; unknown labels pass through."""
    cfg_path = tmp_path / "claire.toml"
    cfg_path.write_text(
        'machine_id = "m"\n'
        'sync_secret = "s"\n'
        'this_host = "plum"\n'
        '\n[web]\nhost = "127.0.0.1"\nport = 8765\n'
        '\n[[known_hosts]]\n'
        'name = "apricot"\n'
        'aliases = ["apri", "apr"]\n',
        encoding="utf-8",
    )

    cfg = load_or_init(cfg_path)
    # Both aliases and the canonical name itself resolve to the canonical name.
    for label in ("apri", "apr", "apricot"):
        assert cfg.resolve_host_label(label) == "apricot"
    # An unrelated label is unchanged.
    assert cfg.resolve_host_label("black") == "black"
|
|
|
|
|
|
def test_serialize_round_trips_host_detection(tmp_path: Path) -> None:
    """this_host and known_hosts survive the sync_secret-backfill rewrite.

    The config deliberately omits sync_secret so that the first load rewrites
    the file; the second load then reads back what the serializer wrote.
    """
    cfg_path = tmp_path / "claire.toml"
    cfg_path.write_text(
        'machine_id = "m"\n'
        'this_host = "plum"\n'
        '\n[web]\nhost = "127.0.0.1"\nport = 8765\n'
        '\n[[known_hosts]]\n'
        'name = "apricot"\n'
        'aliases = ["apri"]\n'
        'description = "dev box"\n',
        encoding="utf-8",
    )

    load_or_init(cfg_path)  # rewrites file (backfills sync_secret)
    # Guard against a vacuous pass: without an actual rewrite, the reload
    # below would simply re-read the hand-written file.
    on_disk = tomllib.loads(cfg_path.read_text(encoding="utf-8"))
    assert "sync_secret" in on_disk

    reloaded = load_or_init(cfg_path)
    assert reloaded.this_host == "plum"
    assert len(reloaded.known_hosts) == 1
    assert reloaded.known_hosts[0].name == "apricot"
    assert reloaded.known_hosts[0].aliases == ["apri"]
    assert reloaded.known_hosts[0].description == "dev box"
|
|
|
|
|
|
def test_serialize_round_trips_per_host_caps(tmp_path: Path) -> None:
    """The [limits] per_host map survives the sync_secret-backfill rewrite."""
    # A config missing sync_secret triggers the migration rewrite, which
    # runs _serialize — the per_host map must survive the round-trip.
    cfg_path = tmp_path / "claire.toml"
    cfg_path.write_text(
        'machine_id = "m"\n'
        "\n[web]\n"
        'host = "127.0.0.1"\n'
        "port = 8765\n"
        "\n[limits]\n"
        "per_host_max = 5\n"
        'per_host = { "apricot" = 10 }\n',
        encoding="utf-8",
    )

    load_or_init(cfg_path)  # rewrites the file (backfills sync_secret)
    # Guard against a vacuous pass: confirm the rewrite really happened, so
    # the reload below exercises the serializer's output.
    on_disk = tomllib.loads(cfg_path.read_text(encoding="utf-8"))
    assert "sync_secret" in on_disk

    reloaded = load_or_init(cfg_path)
    assert reloaded.limits.per_host_max == 5
    assert reloaded.limits.per_host == {"apricot": 10}
|