Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/syft-bg/src/syft_bg/approve/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
self.interval = config.interval
self._config_path = config_path

self._state = JsonStateManager(config.approve_state_path)
self._state = JsonStateManager(state_file=config.approve_state_path)
self._monitors_initialized = False

def setup(self) -> None:
Expand Down
115 changes: 60 additions & 55 deletions packages/syft-bg/src/syft_bg/common/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,38 @@
from pathlib import Path
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError

class JsonStateManager:

class BgState(BaseModel):
"""Typed top-level state for the background services.

``extra="allow"`` preserves the on-disk JSON format for arbitrary top-level
keys written via :meth:`JsonStateManager.set_data` (e.g. ``snapshot``,
``peer_snapshot``, ``email_approve_last_history_id``).
"""

model_config = ConfigDict(extra="allow")

notified_jobs: dict[str, list[str]] = Field(default_factory=dict)
approved_jobs: dict[str, dict[str, str]] = Field(default_factory=dict)
approved_peers: dict[str, dict[str, str]] = Field(default_factory=dict)
thread_ids: dict[str, str] = Field(default_factory=dict)


class JsonStateManager(BaseModel):
"""Manages state persistence with file locking for both notify and approve services."""

def __init__(self, state_file: Path):
self.state_file = Path(state_file).expanduser()
state_file: Path
_lock_file: Path = PrivateAttr()

def model_post_init(self, __context: Any) -> None:
self.state_file = self.state_file.expanduser()
self.state_file.parent.mkdir(parents=True, exist_ok=True)
self._lock_file = self.state_file.with_suffix(".lock")

if not self.state_file.exists():
self._save_all({})
self._save(BgState())

@contextmanager
def _file_lock(self):
Expand All @@ -30,18 +51,17 @@ def _file_lock(self):
finally:
fcntl.flock(lock_handle.fileno(), fcntl.LOCK_UN)

def _load_all(self) -> dict:
def _load(self) -> BgState:
if not self.state_file.exists():
return {}
return BgState()
try:
with open(self.state_file, "r") as f:
return json.load(f)
except (json.JSONDecodeError, OSError, IOError):
return {}
return BgState.model_validate_json(self.state_file.read_text())
except (ValidationError, json.JSONDecodeError, OSError):
return BgState()

def _save_all(self, data: dict):
def _save(self, state: BgState) -> None:
with open(self.state_file, "w") as f:
json.dump(data, f, indent=2)
f.write(state.model_dump_json(indent=2))
self.state_file.chmod(0o600)

def _now_iso(self) -> str:
Expand All @@ -51,86 +71,71 @@ def _now_iso(self) -> str:

def was_notified(self, entity_id: str, event_type: str) -> bool:
"""Check if entity was already notified for event type."""
data = self._load_all()
notified = data.get("notified_jobs", {})
return event_type in notified.get(entity_id, [])
return event_type in self._load().notified_jobs.get(entity_id, [])

def mark_notified(self, entity_id: str, event_type: str) -> None:
"""Mark entity as notified for event type."""
with self._file_lock():
data = self._load_all()
if "notified_jobs" not in data:
data["notified_jobs"] = {}
if entity_id not in data["notified_jobs"]:
data["notified_jobs"][entity_id] = []
if event_type not in data["notified_jobs"][entity_id]:
data["notified_jobs"][entity_id].append(event_type)
self._save_all(data)
state = self._load()
events = state.notified_jobs.setdefault(entity_id, [])
if event_type not in events:
events.append(event_type)
self._save(state)

# --- Approval state (for syft-approve) ---

def was_approved(self, job_name: str) -> bool:
"""Check if job was already approved."""
data = self._load_all()
return job_name in data.get("approved_jobs", {})
return job_name in self._load().approved_jobs

def mark_approved(self, job_name: str, submitted_by: str) -> None:
"""Mark job as approved."""
with self._file_lock():
data = self._load_all()
if "approved_jobs" not in data:
data["approved_jobs"] = {}
data["approved_jobs"][job_name] = {
state = self._load()
state.approved_jobs[job_name] = {
"approved_at": self._now_iso(),
"submitted_by": submitted_by,
}
self._save_all(data)
self._save(state)

def get_approved_jobs(self) -> dict:
"""Get all approved jobs."""
return self._load_all().get("approved_jobs", {})
return self._load().approved_jobs

def was_peer_approved(self, peer_email: str) -> bool:
"""Check if peer was already approved."""
data = self._load_all()
return f"peer_{peer_email}" in data.get("approved_peers", {})
return f"peer_{peer_email}" in self._load().approved_peers

def mark_peer_approved(self, peer_email: str, domain: str) -> None:
"""Mark peer as approved."""
with self._file_lock():
data = self._load_all()
if "approved_peers" not in data:
data["approved_peers"] = {}
data["approved_peers"][f"peer_{peer_email}"] = {
state = self._load()
state.approved_peers[f"peer_{peer_email}"] = {
"approved_at": self._now_iso(),
"domain": domain,
}
self._save_all(data)
self._save(state)

def get_approved_peers(self) -> dict:
"""Get all approved peers."""
return self._load_all().get("approved_peers", {})
return self._load().approved_peers

# --- Email thread tracking ---

def store_thread_id(self, job_name: str, thread_id: str) -> None:
"""Store Gmail thread ID for a job (for threaded notifications)."""
with self._file_lock():
data = self._load_all()
if "thread_ids" not in data:
data["thread_ids"] = {}
data["thread_ids"][job_name] = thread_id
self._save_all(data)
state = self._load()
state.thread_ids[job_name] = thread_id
self._save(state)

def get_thread_id(self, job_name: str) -> Optional[str]:
"""Get stored Gmail thread ID for a job."""
data = self._load_all()
return data.get("thread_ids", {}).get(job_name)
return self._load().thread_ids.get(job_name)

def get_job_name_by_thread_id(self, thread_id: str) -> Optional[str]:
"""Reverse lookup: find job_name for a given Gmail thread ID."""
data = self._load_all()
for job_name, tid in data.get("thread_ids", {}).items():
for job_name, tid in self._load().thread_ids.items():
if tid == thread_id:
return job_name
return None
Expand All @@ -139,21 +144,21 @@ def get_job_name_by_thread_id(self, thread_id: str) -> Optional[str]:

def is_empty(self) -> bool:
"""Check if state has no tracked entities (fresh state)."""
data = self._load_all()
state = self._load()
# Consider empty if no notifications or approvals tracked
has_notified = bool(data.get("notified_jobs"))
has_approved = bool(data.get("approved_jobs") or data.get("approved_peers"))
has_notified = bool(state.notified_jobs)
has_approved = bool(state.approved_jobs or state.approved_peers)
return not (has_notified or has_approved)

# --- Generic data storage ---

def get_data(self, key: str, default: Optional[Any] = None) -> Any:
"""Get arbitrary data by key."""
return self._load_all().get(key, default)
return self._load().model_dump(mode="json").get(key, default)

def set_data(self, key: str, value: Any) -> None:
"""Set arbitrary data by key."""
with self._file_lock():
data = self._load_all()
data[key] = value
self._save_all(data)
state = self._load()
setattr(state, key, value)
self._save(state)
4 changes: 2 additions & 2 deletions packages/syft-bg/src/syft_bg/email_approve/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def from_config(
credentials = GmailAuth().load_credentials(config.gmail_token_path)

watcher = GmailWatcher(credentials)
state = JsonStateManager(config.email_approve_state_path)
notify_state = JsonStateManager(config.notify_state_path)
state = JsonStateManager(state_file=config.email_approve_state_path)
notify_state = JsonStateManager(state_file=config.notify_state_path)

handler = EmailApproveHandler(
job_client=job_client,
Expand Down
4 changes: 2 additions & 2 deletions packages/syft-bg/src/syft_bg/notify/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def from_config(
credentials = GmailAuth().load_credentials(config.gmail_token_path)
sender = GmailSender(credentials)

state_manager = JsonStateManager(config.notify_state_path)
state_manager = JsonStateManager(state_file=config.notify_state_path)

job_handler = JobHandler(
sender,
Expand All @@ -79,7 +79,7 @@ def from_config(
state=state_manager,
)

sync_state = JsonStateManager(config.sync_state_path)
sync_state = JsonStateManager(state_file=config.sync_state_path)

peer_monitor = PeerMonitor(
do_email=config.do_email,
Expand Down
2 changes: 1 addition & 1 deletion packages/syft-bg/src/syft_bg/sync/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def from_config(cls, config: SyncConfig) -> SyncOrchestrator:
force_ignore_peer_version=config.force_ignore_peer_version,
)

state = JsonStateManager(config.sync_state_path)
state = JsonStateManager(state_file=config.sync_state_path)

return cls(
client=client,
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/syft_bg/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ class TestJsonStateManager:
def test_empty_state_file(self, temp_dir):
"""Should handle non-existent state file."""
state_path = temp_dir / "state.json"
state = JsonStateManager(state_path)
state = JsonStateManager(state_file=state_path)

assert not state.was_notified("job1", "new_job")
assert not state.was_approved("job1")

def test_mark_and_check_notified(self, temp_dir):
"""Should track notified jobs by event type."""
state_path = temp_dir / "state.json"
state = JsonStateManager(state_path)
state = JsonStateManager(state_file=state_path)

assert not state.was_notified("job1", "new_job")
state.mark_notified("job1", "new_job")
Expand All @@ -67,7 +67,7 @@ def test_mark_and_check_notified(self, temp_dir):
def test_mark_and_check_approved(self, temp_dir):
"""Should track approved jobs."""
state_path = temp_dir / "state.json"
state = JsonStateManager(state_path)
state = JsonStateManager(state_file=state_path)

assert not state.was_approved("job1")
state.mark_approved("job1", "user@example.com")
Expand All @@ -78,17 +78,17 @@ def test_state_persists_to_file(self, temp_dir):
state_path = temp_dir / "state.json"

# Write state
state1 = JsonStateManager(state_path)
state1 = JsonStateManager(state_file=state_path)
state1.mark_notified("job1", "new_job")

# Read with new instance
state2 = JsonStateManager(state_path)
state2 = JsonStateManager(state_file=state_path)
assert state2.was_notified("job1", "new_job")

def test_state_file_is_valid_json(self, temp_dir):
"""State file should be valid JSON."""
state_path = temp_dir / "state.json"
state = JsonStateManager(state_path)
state = JsonStateManager(state_file=state_path)
state.mark_notified("job1", "new_job")

# Should be parseable as JSON
Expand All @@ -100,7 +100,7 @@ def test_state_file_is_valid_json(self, temp_dir):
def test_get_and_set_data(self, temp_dir):
"""Should support generic get/set operations."""
state_path = temp_dir / "state.json"
state = JsonStateManager(state_path)
state = JsonStateManager(state_file=state_path)

assert state.get_data("custom_key") is None
state.set_data("custom_key", {"foo": "bar"})
Expand All @@ -109,14 +109,14 @@ def test_get_and_set_data(self, temp_dir):
def test_store_and_get_thread_id(self, temp_dir):
"""Should store a thread_id for a job and retrieve it."""
state_path = temp_dir / "state.json"
state = JsonStateManager(state_path)
state = JsonStateManager(state_file=state_path)

state.store_thread_id("job1", "thread-abc-123")
assert state.get_thread_id("job1") == "thread-abc-123"

def test_get_thread_id_nonexistent(self, temp_dir):
"""Should return None for an unknown job's thread_id."""
state_path = temp_dir / "state.json"
state = JsonStateManager(state_path)
state = JsonStateManager(state_file=state_path)

assert state.get_thread_id("nonexistent_job") is None
4 changes: 2 additions & 2 deletions tests/unit/syft_bg/test_email_approval_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _make_notify_orchestrator(

Returns (orchestrator, notify_state, mock_sender).
"""
notify_state = JsonStateManager(tmp / "notify_state.json")
notify_state = JsonStateManager(state_file=tmp / "notify_state.json")

mock_sender = MagicMock()
mock_sender.notify_new_job.return_value = SendResult(
Expand Down Expand Up @@ -74,7 +74,7 @@ def _make_email_approve_orchestrator(
"""
do_email = do_manager.email

email_approve_state = JsonStateManager(tmp / "email_approve_state.json")
email_approve_state = JsonStateManager(state_file=tmp / "email_approve_state.json")
email_approve_state.set_data("email_approve_last_history_id", "10000")

handler = EmailApproveHandler(
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/syft_bg/test_email_approve.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def test_forwarded_message(self):

class TestEmailApproveHandler:
def _make_handler(self, tmp_path):
state = JsonStateManager(tmp_path / "email_approve_state.json")
notify_state = JsonStateManager(tmp_path / "notify_state.json")
state = JsonStateManager(state_file=tmp_path / "email_approve_state.json")
notify_state = JsonStateManager(state_file=tmp_path / "notify_state.json")
job_client = MagicMock()
job_runner = MagicMock()
handler = EmailApproveHandler(
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_job_not_pending(self, tmp_path):

class TestStateReverseLookup:
def test_get_job_name_by_thread_id(self, tmp_path):
state = JsonStateManager(tmp_path / "state.json")
state = JsonStateManager(state_file=tmp_path / "state.json")
state.store_thread_id("job_a", "thread_1")
state.store_thread_id("job_b", "thread_2")

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/syft_bg/test_email_auto_approve_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _make_notify_orchestrator(
tmp: Path,
) -> tuple[NotificationOrchestrator, JsonStateManager, MagicMock]:
"""Create a NotificationOrchestrator with a mocked GmailSender."""
notify_state = JsonStateManager(tmp / "notify_state.json")
notify_state = JsonStateManager(state_file=tmp / "notify_state.json")

mock_sender = MagicMock()
mock_sender.notify_new_job.return_value = SendResult(
Expand Down Expand Up @@ -89,7 +89,7 @@ def _make_email_approve_orchestrator(
"""Create email approve components with mocked GmailWatcher."""
do_email = do_manager.email

email_approve_state = JsonStateManager(tmp / "email_approve_state.json")
email_approve_state = JsonStateManager(state_file=tmp / "email_approve_state.json")
email_approve_state.set_data("email_approve_last_history_id", "10000")

handler = EmailApproveHandler(
Expand Down
Loading
Loading