Coverage for pydantic_graph/pydantic_graph/nodes.py: 97.69%
110 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3import copy
4from abc import ABC, abstractmethod
5from dataclasses import dataclass, is_dataclass
6from functools import cache
7from typing import Any, ClassVar, Generic, get_type_hints
8from uuid import uuid4
10from typing_extensions import Never, Self, TypeVar, get_origin
12from . import _utils, exceptions
__all__ = 'GraphRunContext', 'BaseNode', 'End', 'Edge', 'NodeDef', 'DepsT', 'StateT', 'RunEndT'


StateT = TypeVar('StateT', default=None)
"""Type variable for the state in a graph."""
RunEndT = TypeVar('RunEndT', covariant=True, default=None)
"""Covariant type variable for the return type of a graph [`run`][pydantic_graph.graph.Graph.run]."""
NodeRunEndT = TypeVar('NodeRunEndT', covariant=True, default=Never)
"""Covariant type variable for the return type of a node [`run`][pydantic_graph.nodes.BaseNode.run]."""
DepsT = TypeVar('DepsT', default=None, contravariant=True)
"""Type variable for the dependencies of a graph and node."""
@dataclass
class GraphRunContext(Generic[StateT, DepsT]):
    """Context for a graph."""

    # TODO: Can we get rid of this struct and just pass both these things around..?

    state: StateT
    """The state of the graph."""
    deps: DepsT
    """Dependencies for the graph."""
class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]):
    """Base class for a node."""

    docstring_notes: ClassVar[bool] = False
    """Set to `True` to generate mermaid diagram notes from the class's docstring.

    While this can add valuable information to the diagram, it can make diagrams harder to view, hence
    it is disabled by default. You can also customise notes overriding the
    [`get_note`][pydantic_graph.nodes.BaseNode.get_note] method.
    """

    @abstractmethod
    async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]:
        """Run the node.

        This is an abstract method that must be implemented by subclasses.

        !!! note "Return types used at runtime"
            The return type of this method is read by `pydantic_graph` at runtime and used to define which
            nodes can be called next in the graph. This is displayed in [mermaid diagrams](mermaid.md)
            and enforced when running the graph.

        Args:
            ctx: The graph context.

        Returns:
            The next node to run or [`End`][pydantic_graph.nodes.End] to signal the end of the graph.
        """
        ...

    def get_snapshot_id(self) -> str:
        """Return the snapshot ID of this node instance, generating one on first use.

        The ID is stored in the instance `__dict__` under the literal key `'__snapshot_id'`. If no ID
        (or an empty one) is stored, a new ID is built from the node ID via `generate_snapshot_id`
        and remembered for subsequent calls.
        """
        if snapshot_id := getattr(self, '__snapshot_id', None):
            return snapshot_id
        else:
            self.__dict__['__snapshot_id'] = snapshot_id = generate_snapshot_id(self.get_node_id())
            return snapshot_id

    def set_snapshot_id(self, snapshot_id: str) -> None:
        """Store `snapshot_id` as the snapshot ID of this node instance, replacing any existing one."""
        self.__dict__['__snapshot_id'] = snapshot_id

    @classmethod
    @cache
    def get_node_id(cls) -> str:
        """Get the ID of the node."""
        return cls.__name__

    @classmethod
    def get_note(cls) -> str | None:
        """Get a note about the node to render on mermaid charts.

        By default, this returns a note only if [`docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes]
        is `True`. You can override this method to customise the node notes.
        """
        if not cls.docstring_notes:
            return None
        docstring = cls.__doc__
        # dataclasses get an automatic docstring which is just their signature, we don't want that
        if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('):
            docstring = None
        if docstring:
            # remove indentation from docstring
            import inspect

            docstring = inspect.cleandoc(docstring)
        return docstring

    @classmethod
    def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, NodeRunEndT]:
        """Get the node definition.

        The return type hint of `run` is resolved and each member of the (possibly union) hint is
        classified as an `End`, the bare `BaseNode`, or a `BaseNode` subclass.

        Args:
            local_ns: Local namespace passed to `get_type_hints` when resolving the hints of `run`.

        Returns:
            The `NodeDef` describing this node class and the edges found in its return type hint.

        Raises:
            GraphSetupError: If `run` has no return type hint, or if a member of the return type hint
                is neither `End`, `BaseNode` nor a subclass of `BaseNode`.
        """
        type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True)
        try:
            return_hint = type_hints['return']
        except KeyError as e:
            raise exceptions.GraphSetupError(f'Node {cls} is missing a return type hint on its `run` method') from e

        next_node_edges: dict[str, Edge] = {}
        end_edge: Edge | None = None
        returns_base_node: bool = False
        for return_type in _utils.get_union_args(return_hint):
            return_type, annotations = _utils.unpack_annotated(return_type)
            # use the first `Edge` annotation if there is one, otherwise an unlabelled edge
            edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None))
            return_type_origin = get_origin(return_type) or return_type
            if return_type_origin is End:
                end_edge = edge
            elif return_type_origin is BaseNode:
                # TODO: Should we disallow this?
                returns_base_node = True
            elif isinstance(return_type_origin, type) and issubclass(return_type_origin, BaseNode):
                # the `isinstance` guard matters: `issubclass` raises `TypeError` for non-class hints
                # (e.g. `Literal[...]`), which would bypass the `GraphSetupError` below
                next_node_edges[return_type.get_node_id()] = edge
            else:
                raise exceptions.GraphSetupError(f'Invalid return type: {return_type}')

        return NodeDef(
            cls,
            cls.get_node_id(),
            cls.get_note(),
            next_node_edges,
            end_edge,
            returns_base_node,
        )

    def deep_copy(self) -> Self:
        """Returns a deep copy of the node."""
        return copy.deepcopy(self)
@dataclass
class End(Generic[RunEndT]):
    """Type to return from a node to signal the end of the graph."""

    data: RunEndT
    """Data to return from the graph."""

    def deep_copy_data(self) -> End[RunEndT]:
        """Returns a deep copy of the end of the run.

        When `data` is `None` there is nothing to copy and the same instance is returned; otherwise a
        new `End` wrapping a deep copy of `data` is returned, carrying over this instance's snapshot ID.
        """
        if self.data is None:
            return self
        else:
            end = End(copy.deepcopy(self.data))
            end.set_snapshot_id(self.get_snapshot_id())
            return end

    def get_snapshot_id(self) -> str:
        """Return the snapshot ID of this end marker, generating one on first use.

        The ID is stored in the instance `__dict__` under the literal key `'__snapshot_id'`. If no ID
        (or an empty one) is stored, a new ID is built via `generate_snapshot_id('end')` and remembered.
        """
        if snapshot_id := getattr(self, '__snapshot_id', None):
            return snapshot_id
        else:
            self.__dict__['__snapshot_id'] = snapshot_id = generate_snapshot_id('end')
            return snapshot_id

    def set_snapshot_id(self, set_id: str) -> None:
        """Store `set_id` as the snapshot ID of this end marker, replacing any existing one."""
        self.__dict__['__snapshot_id'] = set_id
def generate_snapshot_id(node_id: str) -> str:
    """Build a new snapshot ID of the form `'<node_id>:<32 hex chars>'` using a random UUID4.

    Args:
        node_id: Prefix for the ID, placed before the colon.

    Returns:
        The node ID followed by `:` and the hex form of a freshly generated UUID4.
    """
    # module method to allow mocking
    return f'{node_id}:{uuid4().hex}'
177@dataclass
178class Edge:
179 """Annotation to apply a label to an edge in a graph."""
181 label: str | None
182 """Label for the edge."""
@dataclass
class NodeDef(Generic[StateT, DepsT, NodeRunEndT]):
    """Definition of a node.

    This is a primarily internal representation of a node; in general, it shouldn't be necessary to use it directly.

    Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating
    mermaid graphs.
    """

    node: type[BaseNode[StateT, DepsT, NodeRunEndT]]
    """The node definition itself."""
    node_id: str
    """ID of the node."""
    note: str | None
    """Note about the node to render on mermaid charts."""
    next_node_edges: dict[str, Edge]
    """IDs of the nodes that can be called next."""
    end_edge: Edge | None
    """If node definition returns an `End` this is an Edge, indicating the node can end the run."""
    returns_base_node: bool
    """The node definition returns a `BaseNode`, hence any node in the graph can be called next."""