Coverage for pydantic_graph/pydantic_graph/state.py: 97.33%

71 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3import copy 

from collections.abc import Mapping, Sequence

5from contextvars import ContextVar 

6from dataclasses import dataclass, field 

7from datetime import datetime 

8from typing import Annotated, Any, Callable, Generic, Literal, Union 

9 

10import pydantic 

11from pydantic_core import core_schema 

12from typing_extensions import TypeVar 

13 

14from . import _utils 

15from .nodes import BaseNode, End, RunEndT 

16 

# Names exported by `from pydantic_graph.state import *`.
__all__ = (
    'StateT',
    'NodeStep',
    'EndStep',
    'HistoryStep',
    'deep_copy_state',
    'nodes_schema_var',
)

18 

19 

# `default=None` lets generics parameterised by `StateT` be used without
# spelling out a state type, in which case the state type is `None`.
StateT = TypeVar('StateT', default=None)
"""Type variable for the state in a graph."""

22 

23 

def deep_copy_state(state: StateT) -> StateT:
    """Snapshot graph state by deep-copying it with [`copy.deepcopy`][copy.deepcopy].

    This is the default snapshot function for a graph run.

    Args:
        state: The current graph state, or `None` when the graph has no state.

    Returns:
        `None` unchanged when `state` is `None`, otherwise an independent deep copy of `state`.
    """
    # Guard clause: there is nothing to copy when the graph carries no state.
    if state is None:
        return state
    return copy.deepcopy(state)

30 

31 

@dataclass
class NodeStep(Generic[StateT, RunEndT]):
    """Record of one node execution in the history of a graph run."""

    state: StateT
    """Graph state once the node has run (stored as a snapshot, see `__post_init__`)."""
    node: Annotated[BaseNode[StateT, Any, RunEndT], CustomNodeSchema()]
    """The node instance that was executed."""
    start_ts: datetime = field(default_factory=_utils.now_utc)
    """Timestamp at which the node started running."""
    duration: float | None = None
    """How long the node ran, in seconds; defaults to `None`."""
    kind: Literal['node'] = 'node'
    """Step kind, always `'node'`; usable as a discriminator when deserializing history."""
    # TODO waiting for https://github.com/pydantic/pydantic/issues/11264, should be an InitVar
    snapshot_state: Annotated[Callable[[StateT], StateT], pydantic.Field(exclude=True, repr=False)] = field(
        default=deep_copy_state, repr=False
    )
    """Callable used to take the snapshot of the graph state."""

    def __post_init__(self):
        # Keep a snapshot instead of the caller's object, so that later
        # mutations made by other code cannot alter this history entry.
        snapshot = self.snapshot_state(self.state)
        self.state = snapshot

    def data_snapshot(self) -> BaseNode[StateT, Any, RunEndT]:
        """Return an independent deep copy of [`self.node`][pydantic_graph.state.NodeStep.node].

        Handy when summarizing history.
        """
        node_copy = copy.deepcopy(self.node)
        return node_copy

62 

63 

@dataclass
class EndStep(Generic[RunEndT]):
    """Record of the final step of a graph run, i.e. the point where it ended."""

    result: End[RunEndT]
    """The `End` value the graph run finished with."""
    ts: datetime = field(default_factory=_utils.now_utc)
    """Timestamp at which the graph run ended."""
    kind: Literal['end'] = 'end'
    """Step kind, always `'end'`; usable as a discriminator when deserializing history."""

    def data_snapshot(self) -> End[RunEndT]:
        """Return an independent deep copy of [`self.result`][pydantic_graph.state.EndStep.result].

        Handy when summarizing history.
        """
        result_copy = copy.deepcopy(self.result)
        return result_copy

81 

82 

# Union of the two history record types. Each member carries a literal `kind`
# field ('node' for NodeStep, 'end' for EndStep) that can be used to tell them
# apart when deserializing.
HistoryStep = Union[NodeStep[StateT, RunEndT], EndStep[RunEndT]]
"""A step in the history of a graph run.

[`Graph.run`][pydantic_graph.graph.Graph.run] returns a list of these steps describing the execution of the graph,
together with the run return value.
"""

89 

90 

# Holds the node classes that `CustomNodeSchema` uses to build the schema of
# `NodeStep.node`. No default is given, so `.get()` raises `LookupError` until
# a value has been set in the current context.
nodes_schema_var: ContextVar[Sequence[type[BaseNode[Any, Any, Any]]]] = ContextVar('nodes_var')

92 

93 

class CustomNodeSchema:
    """Pydantic annotation that supplies the core schema for the `NodeStep.node` field.

    The candidate node classes are not known statically; they are read from
    `nodes_schema_var` when the schema is built. Serialized nodes get an extra
    `'node_id'` key, which is what the discriminator reads back on validation.
    """

    def __get_pydantic_core_schema__(
        self, _source_type: Any, handler: pydantic.GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """Build the validation/serialization schema for a node field.

        Args:
            _source_type: The annotated source type; ignored, the node classes
                from `nodes_schema_var` are used instead.
            handler: Pydantic's handler used to generate the schema for the node type.

        Returns:
            The schema for the node type(s), with serialization wrapped by `_node_serializer`.

        Raises:
            RuntimeError: If `nodes_schema_var` has not been set in the current context.
        """
        try:
            nodes = nodes_schema_var.get()
        except LookupError as e:
            raise RuntimeError(
                'Unable to build a Pydantic schema for `NodeStep` or `HistoryStep` without setting `nodes_schema_var`. '
                'You probably want to set it to the node classes of your graph, '
                'e.g. `nodes_schema_var.set(node_classes)`, before building the schema.'
            ) from e
        if len(nodes) == 1:
            # A single node class needs no tagged union.
            nodes_type = nodes[0]
        else:
            # Tag every node class with its id and dispatch on the `'node_id'` value.
            nodes_annotated = [Annotated[node, pydantic.Tag(node.get_id())] for node in nodes]
            nodes_type = Annotated[Union[tuple(nodes_annotated)], pydantic.Discriminator(self._node_discriminator)]

        schema = handler(nodes_type)
        schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
            function=self._node_serializer,
            return_schema=core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema()),
        )
        return schema

    @staticmethod
    def _node_discriminator(node_data: Any) -> str | None:
        """Return the tag identifying which node class `node_data` belongs to.

        Args:
            node_data: Raw input for a node; a mapping carrying `'node_id'`, or
                an object exposing a callable `get_id()` (such as a node instance).

        Returns:
            The node id, or `None` when it cannot be determined.
        """
        if isinstance(node_data, Mapping):
            return node_data.get('node_id')
        # Non-mapping input has no `.get`; fall back to the object's own id
        # instead of raising `AttributeError`.
        get_id = getattr(node_data, 'get_id', None)
        if callable(get_id):
            return get_id()
        return None

    @staticmethod
    def _node_serializer(node: Any, handler: pydantic.SerializerFunctionWrapHandler) -> dict[str, Any]:
        """Serialize `node` with the wrapped handler and add its id under `'node_id'`.

        Args:
            node: The node being serialized; must provide `get_id()`.
            handler: The wrapped serializer producing the node's dict form.

        Returns:
            The serialized node dict including the `'node_id'` key.
        """
        node_dict = handler(node)
        node_dict['node_id'] = node.get_id()
        return node_dict