Coverage for pydantic_graph/pydantic_graph/nodes.py: 97.69%

110 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3import copy 

4from abc import ABC, abstractmethod 

5from dataclasses import dataclass, is_dataclass 

6from functools import cache 

7from typing import Any, ClassVar, Generic, get_type_hints 

8from uuid import uuid4 

9 

10from typing_extensions import Never, Self, TypeVar, get_origin 

11 

12from . import _utils, exceptions 

13 

# Names exported by `from ... import *`. Note that `NodeRunEndT` and `generate_snapshot_id` are
# defined in this module but are not listed here.
__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'StateT', 'RunEndT'


# Every type variable below declares a `default`, which is why `TypeVar` is taken from
# `typing_extensions` rather than `typing`.

# Invariant; defaults to `None`.
StateT = TypeVar('StateT', default=None)
"""Type variable for the state in a graph."""
# Covariant; defaults to `None`.
RunEndT = TypeVar('RunEndT', covariant=True, default=None)
"""Covariant type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run]."""
# Covariant; defaults to `Never`.
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
"""Covariant type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run]."""
# Contravariant; defaults to `None`.
DepsT = TypeVar('DepsT', default=None, contravariant=True)
"""Type variable for the dependencies of a graph and node."""

25 

26 

@dataclass
class GraphRunContext(Generic[StateT, DepsT]):
    """Context for a graph.

    Bundles the graph `state` and `deps` into the single `ctx` object that is passed to
    [`BaseNode.run`][pydantic_graph.nodes.BaseNode.run].
    """

    # TODO: Can we get rid of this struct and just pass both these things around..?

    state: StateT
    """The state of the graph."""
    deps: DepsT
    """Dependencies for the graph."""

37 

38 

class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]):
    """Base class for a node.

    Subclasses implement [`run`][pydantic_graph.nodes.BaseNode.run]. The return annotation of `run` is
    inspected by [`get_node_def`][pydantic_graph.nodes.BaseNode.get_node_def] to determine which nodes
    may follow this one and whether this node can end the run.
    """

    docstring_notes: ClassVar[bool] = False
    """Set to `True` to generate mermaid diagram notes from the class's docstring.

    While this can add valuable information to the diagram, it can make diagrams harder to view, hence
    it is disabled by default. You can also customise notes overriding the
    [`get_note`][pydantic_graph.nodes.BaseNode.get_note] method.
    """

    @abstractmethod
    async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]:
        """Run the node.

        This is an abstract method that must be implemented by subclasses.

        !!! note "Return types used at runtime"
            The return type of this method is read by `pydantic_graph` at runtime and used to define which
            nodes can be called next in the graph. This is displayed in [mermaid diagrams](mermaid.md)
            and enforced when running the graph.

        Args:
            ctx: The graph context.

        Returns:
            The next node to run or [`End`][pydantic_graph.nodes.End] to signal the end of the graph.
        """
        ...

    def get_snapshot_id(self) -> str:
        """Return the snapshot ID of this node instance.

        If no (truthy) ID has been stored yet, one is generated from the node ID via
        `generate_snapshot_id`, stored on the instance and returned; later calls return the stored value.
        """
        # The ID lives in the instance `__dict__` under the literal key '__snapshot_id'; string
        # literals are not name-mangled, so `getattr` finds exactly that key.
        if snapshot_id := getattr(self, '__snapshot_id', None):
            return snapshot_id
        else:
            self.__dict__['__snapshot_id'] = snapshot_id = generate_snapshot_id(self.get_node_id())
            return snapshot_id

    def set_snapshot_id(self, snapshot_id: str) -> None:
        """Store `snapshot_id` on this instance, replacing any previously stored or generated ID."""
        self.__dict__['__snapshot_id'] = snapshot_id

    @classmethod
    @cache
    def get_node_id(cls) -> str:
        """Get the ID of the node."""
        return cls.__name__

    @classmethod
    def get_note(cls) -> str | None:
        """Get a note about the node to render on mermaid charts.

        By default, this returns a note only if [`docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes]
        is `True`. You can override this method to customise the node notes.

        Returns:
            The class docstring with indentation cleaned up, or `None` if notes are disabled, the class
            has no docstring, or the docstring is the signature auto-generated for a dataclass.
        """
        if not cls.docstring_notes:
            return None
        docstring = cls.__doc__
        # dataclasses get an automatic docstring which is just their signature, we don't want that
        if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('):
            docstring = None
        if docstring:
            # remove indentation from docstring
            import inspect

            docstring = inspect.cleandoc(docstring)
        return docstring

    @classmethod
    def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, NodeRunEndT]:
        """Get the node definition.

        Resolves the type hints of `run` and classifies every member of its return annotation as an
        `End`, a bare `BaseNode`, or a concrete `BaseNode` subclass.

        Args:
            local_ns: Local namespace passed to `get_type_hints` when resolving the annotations of `run`.

        Returns:
            A `NodeDef` describing this node class, its ID, its note and its outgoing edges.

        Raises:
            GraphSetupError: If `run` has no return type hint, or if a member of the return hint is
                neither `End`, `BaseNode` nor a subclass of `BaseNode`.
        """
        type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True)
        try:
            return_hint = type_hints['return']
        except KeyError as e:
            raise exceptions.GraphSetupError(f'Node {cls} is missing a return type hint on its `run` method') from e

        next_node_edges: dict[str, Edge] = {}
        end_edge: Edge | None = None
        returns_base_node: bool = False
        for return_type in _utils.get_union_args(return_hint):
            return_type, annotations = _utils.unpack_annotated(return_type)
            # use the first `Edge` found in the annotations, or an unlabelled edge if there is none
            edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None))
            # strip generic parameters, e.g. `End[int]` -> `End`
            return_type_origin = get_origin(return_type) or return_type
            if return_type_origin is End:
                end_edge = edge
            elif return_type_origin is BaseNode:
                # TODO: Should we disallow this?
                returns_base_node = True
            elif isinstance(return_type_origin, type) and issubclass(return_type_origin, BaseNode):
                # The `isinstance` guard matters: `issubclass` raises a bare `TypeError` when its first
                # argument is not a class (e.g. the origin of a `Literal[...]` hint), which would hide
                # the more helpful `GraphSetupError` raised below.
                next_node_edges[return_type.get_node_id()] = edge
            else:
                raise exceptions.GraphSetupError(f'Invalid return type: {return_type}')

        return NodeDef(
            cls,
            cls.get_node_id(),
            cls.get_note(),
            next_node_edges,
            end_edge,
            returns_base_node,
        )

    def deep_copy(self) -> Self:
        """Returns a deep copy of the node."""
        return copy.deepcopy(self)

143 

144 

@dataclass
class End(Generic[RunEndT]):
    """Value a node returns to signal that the graph run is finished."""

    data: RunEndT
    """Data to return from the graph."""

    def deep_copy_data(self) -> End[RunEndT]:
        """Return an `End` whose `data` is a deep copy of this one's.

        When `data` is `None` there is nothing to copy and this very instance is returned.
        Otherwise the new `End` is given the same snapshot ID as this one.
        """
        if self.data is None:
            return self

        duplicate = End(copy.deepcopy(self.data))
        duplicate.set_snapshot_id(self.get_snapshot_id())
        return duplicate

    def get_snapshot_id(self) -> str:
        """Return the stored snapshot ID, generating and storing one first if none is set."""
        current = getattr(self, '__snapshot_id', None)
        if not current:
            current = generate_snapshot_id('end')
            self.__dict__['__snapshot_id'] = current
        return current

    def set_snapshot_id(self, set_id: str) -> None:
        """Store `set_id` as this instance's snapshot ID."""
        self.__dict__['__snapshot_id'] = set_id

170 

171 

def generate_snapshot_id(node_id: str) -> str:
    """Build a snapshot ID of the form `<node_id>:<32 hex characters of a fresh uuid4>`."""
    # kept as a module-level function so that it can be mocked
    unique_suffix = uuid4().hex
    return f'{node_id}:{unique_suffix}'

175 

176 

@dataclass
class Edge:
    """Annotation to apply a label to an edge in a graph.

    `BaseNode.get_node_def` looks for an `Edge` instance among the annotations attached to each
    return type of a node's `run` method, and falls back to `Edge(None)` when there is none.
    """

    label: str | None
    """Label for the edge."""

183 

184 

@dataclass
class NodeDef(Generic[StateT, DepsT, NodeRunEndT]):
    """Definition of a node.

    This is a primarily internal representation of a node; in general, it shouldn't be necessary to use it directly.

    Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating
    mermaid graphs.
    """

    # Instances are built positionally by `BaseNode.get_node_def`, so the field order below matters.

    # The node class itself (not an instance).
    node: type[BaseNode[StateT, DepsT, NodeRunEndT]]
    """The node definition itself."""
    node_id: str
    """ID of the node."""
    note: str | None
    """Note about the node to render on mermaid charts."""
    # Keyed by the node ID of each possible next node; the value is the `Edge` for that transition.
    next_node_edges: dict[str, Edge]
    """IDs of the nodes that can be called next."""
    end_edge: Edge | None
    """If node definition returns an `End` this is an Edge, indicating the node can end the run."""
    # NOTE(review): "any node in the next" in the text below presumably means "any node in the graph" — confirm.
    returns_base_node: bool
    """The node definition returns a `BaseNode`, hence any node in the next can be called next."""

206 """The node definition returns a `BaseNode`, hence any node in the next can be called next."""