Coverage for pydantic_graph/pydantic_graph/persistence/file.py: 99.15% (108 statements)
coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

from __future__ import annotations as _annotations

import asyncio
import secrets
from collections.abc import AsyncIterator
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path
from time import perf_counter
from typing import Any

import pydantic

from .. import _utils as _graph_utils, exceptions
from ..nodes import BaseNode, End
from . import (
    BaseStatePersistence,
    EndSnapshot,
    NodeSnapshot,
    RunEndT,
    Snapshot,
    SnapshotStatus,
    StateT,
    _utils,
    build_snapshot_list_type_adapter,
)

@dataclass
class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """File-based state persistence that holds graph run state in a JSON file."""

    json_file: Path
    """Path to the JSON file where the snapshots are stored.

    You should use a different file for each graph run, but a single file should be reused for multiple
    steps of the same run.

    For example, if you have a run ID of the form `run_123abc`, you might create a `FileStatePersistence` thus:

    ```py
    from pathlib import Path

    from pydantic_graph import FileStatePersistence

    run_id = 'run_123abc'
    persistence = FileStatePersistence(Path('runs') / f'{run_id}.json')
    ```
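
    Once created, the persistence can be passed to a graph run. A minimal sketch, assuming a
    `Graph` instance `my_graph` and a start node `MyNode` defined elsewhere:

    ```py
    # `my_graph` and `MyNode` are hypothetical, for illustration only
    result = await my_graph.run(MyNode(), persistence=persistence)
    ```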
49 """
    _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field(
        default=None, init=False, repr=False
    )

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        await self._append_save(NodeSnapshot(state=state, node=next_node))

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
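        # Hold the file lock across both the existence check and the append so the
        # check-then-write is atomic with respect to other writers; `lock=False` is
        # passed below because this lock is not reentrant.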
        async with self._lock():
            snapshots = await self.load_all()
            if not any(s.id == snapshot_id for s in snapshots):
                await self._append_save(NodeSnapshot(state=state, node=next_node), lock=False)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        await self._append_save(EndSnapshot(state=state, result=end))

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
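        # Under the lock: mark the snapshot 'running' and stamp its start time, then yield
        # to the caller; afterwards record the duration and final status, 'error' if the
        # wrapped block raised and 'success' otherwise.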
        async with self._lock():
            snapshots = await self.load_all()
            try:
                snapshot = next(s for s in snapshots if s.id == snapshot_id)
            except StopIteration as e:
                raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e

            assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
            exceptions.GraphNodeStatusError.check(snapshot.status)
            snapshot.status = 'running'
            snapshot.start_ts = _utils.now_utc()
            await self._save(snapshots)

        start = perf_counter()
        try:
            yield
        except Exception:
            duration = perf_counter() - start
            async with self._lock():
                await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, duration, 'error')
            raise
        else:
            snapshot.duration = perf_counter() - start
            async with self._lock():
                await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, snapshot.duration, 'success')

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
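        # Claim the first snapshot still in 'created' status by flipping it to 'pending'
        # under the lock, so concurrent callers can't pick up the same node; implicitly
        # returns None when no runnable snapshot is left.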
        async with self._lock():
            snapshots = await self.load_all()
            if snapshot := next((s for s in snapshots if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
                snapshot.status = 'pending'
                await self._save(snapshots)
                return snapshot

    def should_set_types(self) -> bool:
        """Whether types need to be set."""
        return self._snapshots_type_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type)

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        return await _graph_utils.run_in_executor(self._load_sync)

    def _load_sync(self) -> list[Snapshot[StateT, RunEndT]]:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        try:
            content = self.json_file.read_bytes()
        except FileNotFoundError:
            return []
        else:
            return self._snapshots_type_adapter.validate_json(content)

    def _after_run_sync(self, snapshot_id: str, duration: float, status: SnapshotStatus) -> None:
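        # Synchronous helper invoked via `run_in_executor` from `record_run`, which
        # already holds the file lock at that point.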
        snapshots = self._load_sync()
        snapshot = next(s for s in snapshots if s.id == snapshot_id)
        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        snapshot.duration = duration
        snapshot.status = status
        self._save_sync(snapshots)

    async def _save(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None:
        await _graph_utils.run_in_executor(self._save_sync, snapshots)

    def _save_sync(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        self.json_file.write_bytes(self._snapshots_type_adapter.dump_json(snapshots, indent=2))

    async def _append_save(self, snapshot: Snapshot[StateT, RunEndT], *, lock: bool = True) -> None:
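        # `lock=False` lets a caller that already holds the file lock (e.g.
        # `snapshot_node_if_new`) append without trying to re-acquire it.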
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        async with AsyncExitStack() as stack:
            if lock:
                await stack.enter_async_context(self._lock())
            snapshots = await self.load_all()
            snapshots.append(snapshot)
            await self._save(snapshots)

    @asynccontextmanager
    async def _lock(self, *, timeout: float = 1.0) -> AsyncIterator[None]:
        """Lock the file by creating a `.pydantic-graph-persistence-lock` file alongside it.

        Args:
            timeout: how long to wait for the lock.

        Returns: an async context manager that holds the lock.
        """
        lock_file = self.json_file.parent / f'{self.json_file.name}.pydantic-graph-persistence-lock'
        lock_id = secrets.token_urlsafe().encode()
        await asyncio.wait_for(_get_lock(lock_file, lock_id), timeout=timeout)
        try:
            yield
        finally:
            await _graph_utils.run_in_executor(lock_file.unlink, missing_ok=True)

async def _get_lock(lock_file: Path, lock_id: bytes):
    # TODO replace with inline code and `asyncio.timeout` when we drop 3.9
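    # Poll until the lock file is successfully claimed; the overall timeout is enforced
    # by `asyncio.wait_for` in `_lock`.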
    while not await _graph_utils.run_in_executor(_file_append_check, lock_file, lock_id):
        await asyncio.sleep(0.01)

def _file_append_check(file: Path, content: bytes) -> bool:
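    # Best-effort cross-process lock: bail out if the lock file already exists, otherwise
    # append a unique token and check that it landed first. If two processes race past the
    # `exists()` check, both append, but only the one whose token starts the file wins;
    # the loser keeps retrying in `_get_lock`.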
    if file.exists():
        return False

    with file.open(mode='ab') as f:
        f.write(content + b'\n')

    return file.read_bytes().startswith(content)