Coverage for pydantic_graph/pydantic_graph/persistence/in_mem.py: 98.21%

100 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1"""In memory state persistence. 

2 

3This module provides simple in memory state persistence for graphs. 

4""" 

5 

6from __future__ import annotations as _annotations 

7 

8import copy 

9from collections.abc import AsyncIterator 

10from contextlib import asynccontextmanager 

11from dataclasses import dataclass, field 

12from time import perf_counter 

13from typing import Any 

14 

15import pydantic 

16 

17from .. import exceptions 

18from ..nodes import BaseNode, End 

19from . import ( 

20 BaseStatePersistence, 

21 EndSnapshot, 

22 NodeSnapshot, 

23 RunEndT, 

24 Snapshot, 

25 StateT, 

26 _utils, 

27 build_snapshot_list_type_adapter, 

28) 

29 

30 

@dataclass
class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """Simple in memory state persistence that just holds the latest snapshot.

    If no state persistence implementation is provided when running a graph, this is used by default.

    State and nodes are stored as given (no copying), and every new snapshot replaces the previous one.
    """

    last_snapshot: Snapshot[StateT, RunEndT] | None = None
    """The last snapshot."""

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        """Replace the stored snapshot with a `NodeSnapshot` for `next_node`."""
        self.last_snapshot = NodeSnapshot(state=state, node=next_node)

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        """Snapshot `next_node` unless the stored snapshot already has id `snapshot_id`."""
        if self.last_snapshot is not None and self.last_snapshot.id == snapshot_id:
            return
        await self.snapshot_node(state, next_node)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        """Replace the stored snapshot with an `EndSnapshot` holding the run's result."""
        self.last_snapshot = EndSnapshot(state=state, result=end)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        """Track the status and timing of the stored node snapshot while the wrapped block runs.

        The snapshot is marked `'running'` on entry. It is marked `'success'` on normal exit, or
        `'error'` if the block raises an `Exception`, which is then re-raised.

        Raises:
            LookupError: if nothing is stored or the stored snapshot's id is not `snapshot_id`.
        """
        # Bind the snapshot locally: if `last_snapshot` is replaced while the wrapped block runs,
        # duration/status must still be written to the snapshot that was started, not to whatever
        # is stored afterwards (which could be a different snapshot, or None).
        snapshot = self.last_snapshot
        if snapshot is None or snapshot_id != snapshot.id:
            raise LookupError(f'No snapshot found with id={snapshot_id!r}')

        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        # NOTE(review): presumably raises if the snapshot's status does not allow it to be run.
        exceptions.GraphNodeStatusError.check(snapshot.status)
        snapshot.status = 'running'
        snapshot.start_ts = _utils.now_utc()

        start = perf_counter()
        try:
            yield
        except Exception:
            snapshot.duration = perf_counter() - start
            snapshot.status = 'error'
            raise
        else:
            snapshot.duration = perf_counter() - start
            snapshot.status = 'success'

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        """Return the stored snapshot if it is a `NodeSnapshot` with status `'created'`.

        The returned snapshot is marked `'pending'`. In every other case `None` is returned.
        """
        snapshot = self.last_snapshot
        if isinstance(snapshot, NodeSnapshot) and snapshot.status == 'created':
            snapshot.status = 'pending'
            return snapshot
        return None

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        """Not supported, because only the latest snapshot is kept.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('load is not supported for SimpleStatePersistence')

83 

84 

@dataclass
class FullStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """In memory state persistence that records every snapshot in a list.

    `snapshot_node` and `snapshot_end` append to `history`. Existing entries are only mutated in
    place (status and timing), except that `load_json` replaces the whole list.
    """

    deep_copy: bool = True
    """Whether to deep copy the state and nodes when storing them.

    Defaults to `True` so even if nodes or state are modified after the snapshot is taken,
    the persistence history will record the value at the time of the snapshot.
    """
    history: list[Snapshot[StateT, RunEndT]] = field(default_factory=list)
    """List of snapshots taken during the graph run."""
    # Built by `set_types`; `dump_json` / `load_json` require it to be set.
    _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field(
        default=None, init=False, repr=False
    )

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        """Append a `NodeSnapshot` for `next_node`, copying state and node first when `deep_copy` is set."""
        stored_state = self._prep_state(state)
        stored_node = next_node.deep_copy() if self.deep_copy else next_node
        self.history.append(NodeSnapshot(state=stored_state, node=stored_node))

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        """Append a node snapshot only when no entry in `history` already has id `snapshot_id`."""
        # Linear scan: record the node only if every existing id differs from the requested one.
        if all(existing.id != snapshot_id for existing in self.history):
            await self.snapshot_node(state, next_node)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        """Append an `EndSnapshot`, copying state and result data first when `deep_copy` is set."""
        stored_state = self._prep_state(state)
        stored_result = end.deep_copy_data() if self.deep_copy else end
        self.history.append(EndSnapshot(state=stored_state, result=stored_result))

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        """Track the status and timing of node snapshot `snapshot_id` while the wrapped block runs.

        The snapshot is marked `'running'` on entry. It is marked `'success'` on normal exit, or
        `'error'` if the block raises an `Exception`, which is then re-raised.

        Raises:
            LookupError: if no entry in `history` has id `snapshot_id`.
        """
        candidates = (s for s in self.history if s.id == snapshot_id)
        try:
            snapshot = next(candidates)
        except StopIteration as exc:
            raise LookupError(f'No snapshot found with id={snapshot_id!r}') from exc

        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        # NOTE(review): presumably raises if the snapshot's status does not allow it to be run.
        exceptions.GraphNodeStatusError.check(snapshot.status)
        snapshot.status = 'running'
        snapshot.start_ts = _utils.now_utc()

        started = perf_counter()
        try:
            yield
        except Exception:
            snapshot.duration = perf_counter() - started
            snapshot.status = 'error'
            raise
        # Only reached when the wrapped block completed without raising.
        snapshot.duration = perf_counter() - started
        snapshot.status = 'success'

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        """Claim the first `NodeSnapshot` in `history` order whose status is `'created'`.

        The claimed snapshot is marked `'pending'` and returned. Returns `None` if there is none.
        """
        for candidate in self.history:
            if isinstance(candidate, NodeSnapshot) and candidate.status == 'created':
                candidate.status = 'pending'
                return candidate
        return None

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        """Return the `history` list itself (not a copy)."""
        return self.history

    def should_set_types(self) -> bool:
        """Return `True` while the snapshot type adapter has not been built yet."""
        return self._snapshots_type_adapter is None

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        """Build the type adapter used to (de)serialize `history` for the given state / end types."""
        self._snapshots_type_adapter = build_snapshot_list_type_adapter(state_type, run_end_type)

    def dump_json(self, *, indent: int | None = None) -> bytes:
        """Dump the history to JSON bytes; `set_types` must have been called first."""
        adapter = self._snapshots_type_adapter
        assert adapter is not None, 'type adapter must be set to use `dump_json`'
        return adapter.dump_json(self.history, indent=indent)

    def load_json(self, json_data: str | bytes | bytearray) -> None:
        """Replace the history with snapshots parsed from JSON; `set_types` must have been called first."""
        adapter = self._snapshots_type_adapter
        assert adapter is not None, 'type adapter must be set to use `load_json`'
        self.history = adapter.validate_json(json_data)

    def _prep_state(self, state: StateT) -> StateT:
        """Prepare state for snapshot, uses [`copy.deepcopy`][copy.deepcopy] by default."""
        if self.deep_copy and state is not None:
            return copy.deepcopy(state)
        return state