Coverage for pydantic_graph/pydantic_graph/nodes.py: 96.91%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3from abc import ABC, abstractmethod 

4from dataclasses import dataclass, is_dataclass 

5from functools import cache 

6from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_origin, get_type_hints 

7 

8from typing_extensions import Never, TypeVar 

9 

10from . import _utils, exceptions 

11 

# Only type checkers see the real `StateT` from `.state`; at runtime a stand-in
# type variable with the same name is created here instead of importing it.
if TYPE_CHECKING:
    from .state import StateT
else:
    StateT = TypeVar('StateT', default=None)

__all__ = (
    'GraphRunContext',
    'BaseNode',
    'End',
    'Edge',
    'NodeDef',
    'DepsT',
)

RunEndT = TypeVar('RunEndT', default=None, covariant=True)
"""Covariant type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run]."""
NodeRunEndT = TypeVar('NodeRunEndT', default=Never, covariant=True)
"""Covariant type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run]."""
DepsT = TypeVar('DepsT', contravariant=True, default=None)
"""Contravariant type variable for the dependencies of a graph and node."""

25 

26 

@dataclass
class GraphRunContext(Generic[StateT, DepsT]):
    """Context for a graph, pairing its state with its dependencies."""

    state: StateT
    """Current state of the graph."""
    deps: DepsT
    """Dependencies made available to the graph."""

35 

36 

class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]):
    """Base class for a node."""

    docstring_notes: ClassVar[bool] = False
    """Set to `True` to generate mermaid diagram notes from the class's docstring.

    While this can add valuable information to the diagram, it can make diagrams harder to view, hence
    it is disabled by default. You can also customise notes by overriding the
    [`get_note`][pydantic_graph.nodes.BaseNode.get_note] method.
    """

    @abstractmethod
    async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]:
        """Run the node.

        This is an abstract method that must be implemented by subclasses.

        !!! note "Return types used at runtime"
            The return type of this method is read by `pydantic_graph` at runtime and used to define which
            nodes can be called next in the graph. This is displayed in [mermaid diagrams](mermaid.md)
            and enforced when running the graph.

        Args:
            ctx: The graph context.

        Returns:
            The next node to run or [`End`][pydantic_graph.nodes.End] to signal the end of the graph.
        """
        ...

    @classmethod
    @cache
    def get_id(cls) -> str:
        """Get the ID of the node, which is the name of the node class."""
        return cls.__name__

    @classmethod
    def get_note(cls) -> str | None:
        """Get a note about the node to render on mermaid charts.

        By default, this returns a note only if [`docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes]
        is `True`. You can override this method to customise the node notes.

        Returns:
            The class docstring with its indentation cleaned up, or `None` when notes are disabled,
            the class has no docstring, or the docstring is the signature auto-generated for a dataclass.
        """
        if not cls.docstring_notes:
            return None
        docstring = cls.__doc__
        # dataclasses get an automatic docstring which is just their signature, we don't want that
        if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('):
            docstring = None
        if docstring:
            # remove indentation from docstring
            import inspect

            docstring = inspect.cleandoc(docstring)
        return docstring

    @classmethod
    def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, NodeRunEndT]:
        """Get the node definition.

        The return annotation of `run` is inspected to work out which nodes may follow this one
        and whether this node may end the run.

        Args:
            local_ns: Local namespace passed to `get_type_hints` when resolving the annotations of `run`.

        Returns:
            A `NodeDef` holding this class, its ID, its note, the labelled edges to next nodes,
            the end edge (if `run` can return `End`) and whether `run` can return a bare `BaseNode`.

        Raises:
            GraphSetupError: If `run` has no return type hint, or if one of the members of the return
                type is not `End`, `BaseNode` or a subclass of `BaseNode`.
        """
        type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True)
        try:
            return_hint = type_hints['return']
        except KeyError as e:
            raise exceptions.GraphSetupError(f'Node {cls} is missing a return type hint on its `run` method') from e

        next_node_edges: dict[str, Edge] = {}
        end_edge: Edge | None = None
        returns_base_node: bool = False
        for return_type in _utils.get_union_args(return_hint):
            return_type, annotations = _utils.unpack_annotated(return_type)
            # use the first `Edge` found among the annotations, falling back to an unlabelled edge
            edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None))
            # for a parametrised generic such as `End[int]` compare against the unsubscripted class
            return_type_origin = get_origin(return_type) or return_type
            if return_type_origin is End:
                end_edge = edge
            elif return_type_origin is BaseNode:
                # TODO: Should we disallow this?
                returns_base_node = True
            elif isinstance(return_type_origin, type) and issubclass(return_type_origin, BaseNode):
                # the `isinstance` guard matters: `issubclass` raises `TypeError` for hints that are
                # not classes (e.g. `Literal[...]`), which should be reported as a setup error below
                next_node_edges[return_type.get_id()] = edge
            else:
                raise exceptions.GraphSetupError(f'Invalid return type: {return_type}')

        return NodeDef(
            cls,
            cls.get_id(),
            cls.get_note(),
            next_node_edges,
            end_edge,
            returns_base_node,
        )

127 

128 

@dataclass
class End(Generic[RunEndT]):
    """Value a node returns to signal the end of the graph."""

    data: RunEndT
    """Payload to return from the graph."""

135 

136 

@dataclass
class Edge:
    """Annotation that attaches a label to an edge in a graph."""

    label: str | None
    """Text label for the edge; `None` leaves the edge unlabelled."""

143 

144 

@dataclass
class NodeDef(Generic[StateT, DepsT, NodeRunEndT]):
    """Definition of a node.

    This is primarily an internal representation of a node and should rarely need to be used directly.

    [`Graph`][pydantic_graph.graph.Graph] uses it to store information about a node, and it is also
    used when generating mermaid graphs.
    """

    node: type[BaseNode[StateT, DepsT, NodeRunEndT]]
    """The node class this definition describes."""
    node_id: str
    """ID of the node."""
    note: str | None
    """Note about the node to render on mermaid charts, if any."""
    next_node_edges: dict[str, Edge]
    """Edges to the nodes that can be called next, keyed by the ID of the next node."""
    end_edge: Edge | None
    """An `Edge` if the node can return `End` (i.e. it can end the run), otherwise `None`."""
    returns_base_node: bool
    """Whether the node returns a bare `BaseNode`, in which case any node in the graph can be called next."""