Coverage for pydantic_graph/pydantic_graph/nodes.py: 96.91%
83 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3from abc import ABC, abstractmethod
4from dataclasses import dataclass, is_dataclass
5from functools import cache
6from typing import TYPE_CHECKING, Any, ClassVar, Generic, get_origin, get_type_hints
8from typing_extensions import Never, TypeVar
10from . import _utils, exceptions
if not TYPE_CHECKING:
    # At runtime a local TypeVar (default `None`) stands in for `StateT`;
    # type checkers take the real definition from `.state` instead.
    StateT = TypeVar('StateT', default=None)
else:
    from .state import StateT

__all__ = (
    'GraphRunContext',
    'BaseNode',
    'End',
    'Edge',
    'NodeDef',
    'DepsT',
)

RunEndT = TypeVar('RunEndT', default=None, covariant=True)
"""Covariant type variable for the value a graph [`run`][pydantic_graph.graph.Graph.run] returns; defaults to `None`."""
NodeRunEndT = TypeVar('NodeRunEndT', default=Never, covariant=True)
"""Covariant type variable for the end value a node's [`run`][pydantic_graph.nodes.BaseNode.run] may return; defaults to `Never`."""
DepsT = TypeVar('DepsT', contravariant=True, default=None)
"""Contravariant type variable for the dependencies of a graph and its nodes; defaults to `None`."""
@dataclass(init=True, repr=True, eq=True)
class GraphRunContext(Generic[StateT, DepsT]):
    """Context handed to a node's `run` method: the graph's state together with its dependencies."""

    state: StateT
    """The state of the graph."""
    deps: DepsT
    """Dependencies for the graph."""
class BaseNode(ABC, Generic[StateT, DepsT, NodeRunEndT]):
    """Base class for a node."""

    docstring_notes: ClassVar[bool] = False
    """Set to `True` to generate mermaid diagram notes from the class's docstring.

    While this can add valuable information to the diagram, it can make diagrams harder to view, hence
    it is disabled by default. You can also customise notes by overriding the
    [`get_note`][pydantic_graph.nodes.BaseNode.get_note] method.
    """

    @abstractmethod
    async def run(self, ctx: GraphRunContext[StateT, DepsT]) -> BaseNode[StateT, DepsT, Any] | End[NodeRunEndT]:
        """Run the node.

        This is an abstract method that must be implemented by subclasses.

        !!! note "Return types used at runtime"
            The return type of this method is read by `pydantic_graph` at runtime and used to define which
            nodes can be called next in the graph. This is displayed in [mermaid diagrams](mermaid.md)
            and enforced when running the graph.

        Args:
            ctx: The graph context.

        Returns:
            The next node to run or [`End`][pydantic_graph.nodes.End] to signal the end of the graph.
        """
        ...

    @classmethod
    @cache
    def get_id(cls) -> str:
        """Get the ID of the node.

        Defaults to the class name; the result is memoised per class by `functools.cache`.
        """
        return cls.__name__

    @classmethod
    def get_note(cls) -> str | None:
        """Get a note about the node to render on mermaid charts.

        By default, this returns a note only if [`docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes]
        is `True`. You can override this method to customise the node notes.

        Returns:
            The class docstring with indentation removed, or `None` if notes are disabled, the class has
            no docstring, or the docstring is the signature auto-generated by `@dataclass`.
        """
        if not cls.docstring_notes:
            return None
        docstring = cls.__doc__
        # dataclasses get an automatic docstring which is just their signature, we don't want that
        if docstring and is_dataclass(cls) and docstring.startswith(f'{cls.__name__}('):
            docstring = None
        if docstring:
            # remove indentation from docstring
            import inspect

            docstring = inspect.cleandoc(docstring)
        return docstring

    @classmethod
    def get_node_def(cls, local_ns: dict[str, Any] | None) -> NodeDef[StateT, DepsT, NodeRunEndT]:
        """Get the node definition.

        The return type hint of `run` is resolved and each of its members is classified as either
        `End` (the node can end the run), bare `BaseNode` (any node may follow), or a concrete
        `BaseNode` subclass (that specific node may follow).

        Args:
            local_ns: Optional local namespace passed to `get_type_hints` when resolving the hints of `run`.

        Returns:
            A `NodeDef` describing this node and its outgoing edges.

        Raises:
            GraphSetupError: If `run` has no return type hint, or a member of the return type is not
                `End`, `BaseNode`, or a subclass of `BaseNode`.
        """
        type_hints = get_type_hints(cls.run, localns=local_ns, include_extras=True)
        try:
            return_hint = type_hints['return']
        except KeyError as e:
            raise exceptions.GraphSetupError(f'Node {cls} is missing a return type hint on its `run` method') from e

        next_node_edges: dict[str, Edge] = {}
        end_edge: Edge | None = None
        returns_base_node: bool = False
        # presumably `get_union_args` yields the members of a union hint (or the hint itself) — see `_utils`
        for return_type in _utils.get_union_args(return_hint):
            return_type, annotations = _utils.unpack_annotated(return_type)
            # use the first `Edge` among the hint's annotations as the edge label, else an unlabelled edge
            edge = next((a for a in annotations if isinstance(a, Edge)), Edge(None))
            # strip generic parameters, e.g. `End[int]` -> `End`
            return_type_origin = get_origin(return_type) or return_type
            if return_type_origin is End:
                end_edge = edge
            elif return_type_origin is BaseNode:
                # TODO: Should we disallow this?
                returns_base_node = True
            elif isinstance(return_type_origin, type) and issubclass(return_type_origin, BaseNode):
                next_node_edges[return_type_origin.get_id()] = edge
            else:
                # `issubclass` raises TypeError for non-class hints (e.g. `Literal[...]` or a TypeVar);
                # the `isinstance(..., type)` guard above routes those here as a proper setup error.
                raise exceptions.GraphSetupError(f'Invalid return type: {return_type}')

        return NodeDef(
            cls,
            cls.get_id(),
            cls.get_note(),
            next_node_edges,
            end_edge,
            returns_base_node,
        )
@dataclass(init=True, repr=True, eq=True)
class End(Generic[RunEndT]):
    """Type to return from a node to signal the end of the graph.

    Wraps the value that the graph run should produce.
    """

    data: RunEndT
    """Data to return from the graph."""
137@dataclass
138class Edge:
139 """Annotation to apply a label to an edge in a graph."""
141 label: str | None
142 """Label for the edge."""
@dataclass(init=True, repr=True, eq=True)
class NodeDef(Generic[StateT, DepsT, NodeRunEndT]):
    """Definition of a node.

    This is a primarily internal representation of a node; in general, it shouldn't be necessary to use it directly.

    Used by [`Graph`][pydantic_graph.graph.Graph] to store information about a node, and when generating
    mermaid graphs.
    """

    node: type[BaseNode[StateT, DepsT, NodeRunEndT]]
    """The node class itself."""
    node_id: str
    """ID of the node."""
    note: str | None
    """Note about the node to render on mermaid charts, if any."""
    next_node_edges: dict[str, Edge]
    """Maps the IDs of the nodes that can be called next to the edge leading to each of them."""
    end_edge: Edge | None
    """The edge to `End` if the node can end the run, otherwise `None`."""
    returns_base_node: bool
    """Whether the node's `run` is annotated to return a bare `BaseNode`, meaning any node in the graph can be called next."""