Coverage for pydantic_graph/pydantic_graph/persistence/file.py: 99.15%

108 statements  

from __future__ import annotations as _annotations

import asyncio
import secrets
from collections.abc import AsyncIterator
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path
from time import perf_counter
from typing import Any

import pydantic

from .. import _utils as _graph_utils, exceptions
from ..nodes import BaseNode, End
from . import (
    BaseStatePersistence,
    EndSnapshot,
    NodeSnapshot,
    RunEndT,
    Snapshot,
    SnapshotStatus,
    StateT,
    _utils,
    build_snapshot_list_type_adapter,
)

@dataclass
class FileStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """File-based state persistence that holds graph run state in a JSON file."""

    json_file: Path
    """Path to the JSON file where the snapshots are stored.

    You should use a different file for each graph run, but a single file should be reused for multiple
    steps of the same run.

    For example, if you have a run ID of the form `run_123abc`, you might create a `FileStatePersistence` thus:

    ```py
    from pathlib import Path

    from pydantic_graph.persistence.file import FileStatePersistence

    run_id = 'run_123abc'
    persistence = FileStatePersistence(Path('runs') / f'{run_id}.json')
    ```
    """
    _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field(
        default=None, init=False, repr=False
    )

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        await self._append_save(NodeSnapshot(state=state, node=next_node))

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
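        # Hold the lock across the check and the append so two concurrent calls can't
        # both write the same snapshot; `lock=False` avoids re-acquiring the held lock.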
        async with self._lock():
            snapshots = await self.load_all()
            if not any(s.id == snapshot_id for s in snapshots):
                await self._append_save(NodeSnapshot(state=state, node=next_node), lock=False)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        await self._append_save(EndSnapshot(state=state, result=end))

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
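        # Under the lock: find the snapshot, check it hasn't already run, then mark it
        # running and stamp its start time. The node itself executes outside the lock;
        # its duration and final status are written back on exit below.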
        async with self._lock():
            snapshots = await self.load_all()
            try:
                snapshot = next(s for s in snapshots if s.id == snapshot_id)
            except StopIteration as e:
                raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e

            assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
            exceptions.GraphNodeStatusError.check(snapshot.status)
            snapshot.status = 'running'
            snapshot.start_ts = _utils.now_utc()
            await self._save(snapshots)

        start = perf_counter()
        try:
            yield
        except Exception:
            duration = perf_counter() - start
            async with self._lock():
                await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, duration, 'error')
            raise
        else:
            snapshot.duration = perf_counter() - start
            async with self._lock():
                await _graph_utils.run_in_executor(self._after_run_sync, snapshot_id, snapshot.duration, 'success')

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
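        # Claim the next 'created' node snapshot by flipping it to 'pending' under the
        # lock, so concurrent consumers of the same file don't pick up the same node.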
        async with self._lock():
            snapshots = await self.load_all()
            if snapshot := next((s for s in snapshots if isinstance(s, NodeSnapshot) and s.status == 'created'), None):
                snapshot.status = 'pending'
                await self._save(snapshots)
                return snapshot

    def should_set_types(self) -> bool:
        """Whether types need to be set."""
        return self._snapshots_type_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type)

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        return await _graph_utils.run_in_executor(self._load_sync)

    def _load_sync(self) -> list[Snapshot[StateT, RunEndT]]:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        try:
            content = self.json_file.read_bytes()
        except FileNotFoundError:
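            # A missing file just means no snapshots have been persisted yet.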
            return []
        else:
            return self._snapshots_type_adapter.validate_json(content)

    def _after_run_sync(self, snapshot_id: str, duration: float, status: SnapshotStatus) -> None:
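        # Re-read the file so snapshots written while the node was running aren't
        # clobbered; runs in an executor with the file lock held by the caller.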
        snapshots = self._load_sync()
        snapshot = next(s for s in snapshots if s.id == snapshot_id)
        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        snapshot.duration = duration
        snapshot.status = status
        self._save_sync(snapshots)

    async def _save(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None:
        await _graph_utils.run_in_executor(self._save_sync, snapshots)

    def _save_sync(self, snapshots: list[Snapshot[StateT, RunEndT]]) -> None:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        self.json_file.write_bytes(self._snapshots_type_adapter.dump_json(snapshots, indent=2))

    async def _append_save(self, snapshot: Snapshot[StateT, RunEndT], *, lock: bool = True) -> None:
        assert self._snapshots_type_adapter is not None, 'snapshots type adapter must be set'
        async with AsyncExitStack() as stack:
            if lock:
                await stack.enter_async_context(self._lock())
            snapshots = await self.load_all()
            snapshots.append(snapshot)
            await self._save(snapshots)

    @asynccontextmanager
    async def _lock(self, *, timeout: float = 1.0) -> AsyncIterator[None]:
        """Lock the snapshots file by creating a sibling `.pydantic-graph-persistence-lock` file.

        Args:
            timeout: how long to wait for the lock before giving up.

        Returns: an async context manager that holds the lock.
        """
        lock_file = self.json_file.parent / f'{self.json_file.name}.pydantic-graph-persistence-lock'
        lock_id = secrets.token_urlsafe().encode()
        await asyncio.wait_for(_get_lock(lock_file, lock_id), timeout=timeout)
        try:
            yield
        finally:
            await _graph_utils.run_in_executor(lock_file.unlink, missing_ok=True)


async def _get_lock(lock_file: Path, lock_id: bytes) -> None:
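    # Poll until this process's token wins the lock file; `_lock` bounds the total
    # wait with `asyncio.wait_for` above.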
    # TODO replace with inline code and `asyncio.timeout` when we drop 3.9
    while not await _graph_utils.run_in_executor(_file_append_check, lock_file, lock_id):
        await asyncio.sleep(0.01)


def _file_append_check(file: Path, content: bytes) -> bool:
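    # Best-effort mutual exclusion without OS-level file locking: if the lock file
    # already exists, another process holds the lock; otherwise append our token and
    # read the file back, claiming the lock only if our token is first. The read-back
    # catches two processes appending at nearly the same time.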
    if file.exists():
        return False

    with file.open(mode='ab') as f:
        f.write(content + b'\n')

    return file.read_bytes().startswith(content)
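

# Minimal usage sketch (illustrative only, not part of the module). `MyState` and
# `MyNode` are hypothetical stand-ins for a real graph's state and node types, and an
# `int` run end type is assumed; in practice `Graph` drives these calls itself:
#
#     persistence = FileStatePersistence(Path('runs/run_123abc.json'))
#     if persistence.should_set_types():
#         persistence.set_types(MyState, int)
#     await persistence.snapshot_node(MyState(), MyNode())
#     if snapshot := await persistence.load_next():  # marks the snapshot 'pending'
#         async with persistence.record_run(snapshot.id):
#             ...  # run the node; duration and final status are recorded on exit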