Coverage for pydantic_graph/pydantic_graph/persistence/in_mem.py: 98.21%
100 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1"""In memory state persistence.
3This module provides simple in memory state persistence for graphs.
4"""
6from __future__ import annotations as _annotations
8import copy
9from collections.abc import AsyncIterator
10from contextlib import asynccontextmanager
11from dataclasses import dataclass, field
12from time import perf_counter
13from typing import Any
15import pydantic
17from .. import exceptions
18from ..nodes import BaseNode, End
19from . import (
20 BaseStatePersistence,
21 EndSnapshot,
22 NodeSnapshot,
23 RunEndT,
24 Snapshot,
25 StateT,
26 _utils,
27 build_snapshot_list_type_adapter,
28)
@dataclass
class SimpleStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """Simple in memory state persistence that just holds the latest snapshot.

    If no state persistence implementation is provided when running a graph, this is used by default.

    Each call to `snapshot_node` / `snapshot_end` replaces the previously stored snapshot, and the
    state and node objects are stored by reference (no copying is done).
    """

    last_snapshot: Snapshot[StateT, RunEndT] | None = None
    """The last snapshot."""

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        """Replace the stored snapshot with a new `NodeSnapshot` for `next_node`."""
        self.last_snapshot = NodeSnapshot(state=state, node=next_node)

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        """Snapshot `next_node` unless the stored snapshot already has id `snapshot_id`.

        Only the latest snapshot can be compared, since it is the only one retained.
        """
        if self.last_snapshot and self.last_snapshot.id == snapshot_id:
            return
        await self.snapshot_node(state, next_node)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        """Replace the stored snapshot with an `EndSnapshot` holding the run's result."""
        self.last_snapshot = EndSnapshot(state=state, result=end)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        """Record status, start time and duration while the node snapshot `snapshot_id` runs.

        The snapshot is marked 'running' on entry, then 'success' on normal exit or 'error' if
        the body raises an `Exception` (which is re-raised).

        Raises:
            LookupError: if the stored snapshot is missing or does not have id `snapshot_id`.
            exceptions.GraphNodeStatusError: presumably, via `GraphNodeStatusError.check`, when the
                snapshot's current status does not allow it to run — confirm against that helper.
        """
        # Bind the snapshot once: code running inside the `async with` body may replace
        # `self.last_snapshot`, and the timing/status below must land on the snapshot that ran.
        snapshot = self.last_snapshot
        if snapshot is None or snapshot_id != snapshot.id:
            raise LookupError(f'No snapshot found with id={snapshot_id!r}')

        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        exceptions.GraphNodeStatusError.check(snapshot.status)
        snapshot.status = 'running'
        snapshot.start_ts = _utils.now_utc()

        start = perf_counter()
        try:
            yield
        except Exception:
            snapshot.duration = perf_counter() - start
            snapshot.status = 'error'
            raise
        else:
            snapshot.duration = perf_counter() - start
            snapshot.status = 'success'

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        """Return the stored snapshot if it is a `NodeSnapshot` with status 'created'.

        The returned snapshot is marked 'pending'. Otherwise `None` is returned.
        """
        if isinstance(self.last_snapshot, NodeSnapshot) and self.last_snapshot.status == 'created':
            self.last_snapshot.status = 'pending'
            return self.last_snapshot
        return None

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        """Not supported: only the latest snapshot is retained.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('load is not supported for SimpleStatePersistence')
@dataclass
class FullStatePersistence(BaseStatePersistence[StateT, RunEndT]):
    """In memory state persistence that appends every snapshot to an ordered list."""

    deep_copy: bool = True
    """Whether state and nodes are deep copied before being stored.

    With the default of `True`, later mutation of the live state or nodes does not alter
    the history: each entry keeps the values as they were when the snapshot was taken.
    """
    history: list[Snapshot[StateT, RunEndT]] = field(default_factory=list)
    """Snapshots recorded during the graph run, oldest first."""
    _snapshots_type_adapter: pydantic.TypeAdapter[list[Snapshot[StateT, RunEndT]]] | None = field(
        default=None, init=False, repr=False
    )

    async def snapshot_node(self, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]) -> None:
        """Append a `NodeSnapshot` for `next_node`, copying state and node when `deep_copy` is set."""
        stored_state = self._prep_state(state)
        stored_node = next_node if not self.deep_copy else next_node.deep_copy()
        self.history.append(NodeSnapshot(state=stored_state, node=stored_node))

    async def snapshot_node_if_new(
        self, snapshot_id: str, state: StateT, next_node: BaseNode[StateT, Any, RunEndT]
    ) -> None:
        """Append a snapshot for `next_node` only if no entry in `history` has id `snapshot_id`."""
        for existing in self.history:
            if existing.id == snapshot_id:
                return
        await self.snapshot_node(state, next_node)

    async def snapshot_end(self, state: StateT, end: End[RunEndT]) -> None:
        """Append an `EndSnapshot`, copying state and result data when `deep_copy` is set."""
        stored_state = self._prep_state(state)
        stored_result = end if not self.deep_copy else end.deep_copy_data()
        self.history.append(EndSnapshot(state=stored_state, result=stored_result))

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        """Track status, start time and duration of running the node snapshot `snapshot_id`.

        Raises:
            LookupError: if no snapshot in `history` has id `snapshot_id`.
        """
        matches = (candidate for candidate in self.history if candidate.id == snapshot_id)
        try:
            snapshot = next(matches)
        except StopIteration as e:
            raise LookupError(f'No snapshot found with id={snapshot_id!r}') from e

        assert isinstance(snapshot, NodeSnapshot), 'Only NodeSnapshot can be recorded'
        exceptions.GraphNodeStatusError.check(snapshot.status)
        snapshot.status = 'running'
        snapshot.start_ts = _utils.now_utc()

        started_at = perf_counter()
        try:
            yield
        except Exception:
            # Record how long the failed run took before propagating the error.
            snapshot.duration = perf_counter() - started_at
            snapshot.status = 'error'
            raise
        snapshot.duration = perf_counter() - started_at
        snapshot.status = 'success'

    async def load_next(self) -> NodeSnapshot[StateT, RunEndT] | None:
        """Return the first `NodeSnapshot` with status 'created', marking it 'pending'; else `None`."""
        for candidate in self.history:
            if isinstance(candidate, NodeSnapshot) and candidate.status == 'created':
                candidate.status = 'pending'
                return candidate
        return None

    async def load_all(self) -> list[Snapshot[StateT, RunEndT]]:
        """Return the `history` list itself (not a copy)."""
        return self.history

    def should_set_types(self) -> bool:
        """Return `True` while no snapshot type adapter has been configured."""
        adapter_missing = self._snapshots_type_adapter is None
        return adapter_missing

    def set_types(self, state_type: type[StateT], run_end_type: type[RunEndT]) -> None:
        """Build and store the type adapter used by `dump_json` / `load_json`."""
        adapter = build_snapshot_list_type_adapter(state_type, run_end_type)
        self._snapshots_type_adapter = adapter

    def dump_json(self, *, indent: int | None = None) -> bytes:
        """Serialize `history` to JSON bytes using the configured type adapter."""
        adapter = self._snapshots_type_adapter
        assert adapter is not None, 'type adapter must be set to use `dump_json`'
        return adapter.dump_json(self.history, indent=indent)

    def load_json(self, json_data: str | bytes | bytearray) -> None:
        """Replace `history` with snapshots validated from `json_data`."""
        adapter = self._snapshots_type_adapter
        assert adapter is not None, 'type adapter must be set to use `load_json`'
        self.history = adapter.validate_json(json_data)

    def _prep_state(self, state: StateT) -> StateT:
        """Prepare state for a snapshot: a [`copy.deepcopy`][copy.deepcopy] when `deep_copy` is set."""
        if self.deep_copy and state is not None:
            return copy.deepcopy(state)
        return state