Coverage for pydantic_graph/pydantic_graph/state.py: 97.33%
71 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3import copy
4from collections.abc import Sequence
5from contextvars import ContextVar
6from dataclasses import dataclass, field
7from datetime import datetime
8from typing import Annotated, Any, Callable, Generic, Literal, Union
10import pydantic
11from pydantic_core import core_schema
12from typing_extensions import TypeVar
14from . import _utils
15from .nodes import BaseNode, End, RunEndT
# Public names exported by this module.
__all__ = (
    'StateT',
    'NodeStep',
    'EndStep',
    'HistoryStep',
    'deep_copy_state',
    'nodes_schema_var',
)


StateT = TypeVar('StateT', default=None)
"""Type variable for the graph's state; defaults to `None`."""
def deep_copy_state(state: StateT) -> StateT:
    """Snapshot graph state using [`copy.deepcopy`][copy.deepcopy]; the default snapshot function.

    `None` is handed back unchanged, any other value is deep-copied.
    """
    return state if state is None else copy.deepcopy(state)
@dataclass
class NodeStep(Generic[StateT, RunEndT]):
    """A history entry recording the execution of one node in a graph."""

    state: StateT
    """Graph state after the node has been run."""
    node: Annotated[BaseNode[StateT, Any, RunEndT], CustomNodeSchema()]
    """The node that was executed."""
    start_ts: datetime = field(default_factory=_utils.now_utc)
    """Timestamp at which the node started running."""
    duration: float | None = None
    """How long the node ran for, in seconds."""
    kind: Literal['node'] = 'node'
    """Kind of history step; usable as a discriminator when deserializing history."""
    # TODO waiting for https://github.com/pydantic/pydantic/issues/11264, should be an InitVar
    snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True, repr=False)] = field(
        default=deep_copy_state, repr=False
    )
    """Callable used to snapshot the state of the graph."""

    def __post_init__(self):
        # Keep a snapshot instead of the caller's object so other code cannot mutate the recorded state
        snapshot = self.snapshot_state(self.state)
        self.state = snapshot

    def data_snapshot(self) -> BaseNode[StateT, Any, RunEndT]:
        """Return a deep copy of [`self.node`][pydantic_graph.state.NodeStep.node].

        Handy when summarizing history.
        """
        node_copy = copy.deepcopy(self.node)
        return node_copy
@dataclass
class EndStep(Generic[RunEndT]):
    """A history entry recording the end of a graph run."""

    result: End[RunEndT]
    """Result of the graph run."""
    ts: datetime = field(default_factory=_utils.now_utc)
    """Timestamp at which the graph run ended."""
    kind: Literal['end'] = 'end'
    """Kind of history step; usable as a discriminator when deserializing history."""

    def data_snapshot(self) -> End[RunEndT]:
        """Return a deep copy of [`self.result`][pydantic_graph.state.EndStep.result].

        Handy when summarizing history.
        """
        result_copy = copy.deepcopy(self.result)
        return result_copy
HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[RunEndT]]
"""A step in the history of a graph run: either a `NodeStep` or an `EndStep`.

[`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph,
together with the run return value.
"""


# Context variable holding the node classes that `CustomNodeSchema` builds its schema from.
# It is created without a default, so `.get()` raises `LookupError` until a value has been set.
nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any, Any]]]] = ContextVar('nodes_var')
class CustomNodeSchema:
    """Pydantic schema hook for the `node` field of `NodeStep`.

    Intended to be used as `Annotated[..., CustomNodeSchema()]` metadata. The schema is built
    from the node classes currently stored in `nodes_schema_var`, and serialization adds a
    `node_id` key which the discriminator reads back to pick the matching node class.
    """

    def __get_pydantic_core_schema__(
        self, _source_type: Any, handler: pydantic.GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """Build the core schema for a node from the classes in `nodes_schema_var`.

        Args:
            _source_type: The annotated source type; unused, the node classes come from `nodes_schema_var`.
            handler: Pydantic handler used to generate the schema for the resolved node type.

        Returns:
            The generated schema with a wrap serializer that adds `node_id` to the output.

        Raises:
            RuntimeError: If `nodes_schema_var` has not been set in the current context.
        """
        try:
            nodes = nodes_schema_var.get()
        except LookupError as e:
            raise RuntimeError(
                'Unable to build a Pydantic schema for `NodeStep` or `HistoryStep` without setting `nodes_schema_var`. '
                'Set `nodes_schema_var` to the sequence of node classes to validate against '
                '(e.g. with `nodes_schema_var.set(...)`) before building the schema.'
            ) from e
        if len(nodes) == 1:
            # A single node class needs no tagged union
            nodes_type = nodes[0]
        else:
            # Tag every node class with its id; the discriminator selects the union member by that tag
            nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes]
            nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)]

        schema = handler(nodes_type)
        # Wrap the generated serializer so the serialized dict also carries the node id
        schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
            function=self._node_serializer,
            return_schema=core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema()),
        )
        return schema

    @staticmethod
    def _node_discriminator(node_data: Any) -> str | None:
        """Extract the node id used as the union tag from the value being validated.

        Objects exposing a callable `get_id` (such as already-constructed nodes) are tagged by
        calling it; anything with a `.get` method (serialized dicts) is tagged by its `'node_id'`
        entry. Returns `None` when no tag can be determined, so that validation reports a missing
        tag rather than this function raising `AttributeError`.
        """
        get_id = getattr(node_data, 'get_id', None)
        if callable(get_id):
            return get_id()
        getter = getattr(node_data, 'get', None)
        if callable(getter):
            return getter('node_id')
        return None

    @staticmethod
    def _node_serializer(node: Any, handler: pydantic.SerializerFunctionWrapHandler) -> dict[str, Any]:
        """Serialize `node` with the wrapped handler and add its id under the `'node_id'` key."""
        node_dict = handler(node)
        node_dict['node_id'] = node.get_id()
        return node_dict