Coverage for pydantic_graph/pydantic_graph/graph.py: 97.78%
155 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3import asyncio
4import inspect
5import types
6from collections.abc import Sequence
7from dataclasses import dataclass, field
8from functools import cached_property
9from pathlib import Path
10from time import perf_counter
11from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, TypeVar
13import logfire_api
14import pydantic
15import typing_extensions
17from . import _utils, exceptions, mermaid
18from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT
19from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var
# Only `Graph` is part of this module's public API.
__all__ = ('Graph',)

# Module-wide Logfire handle used to open the spans around graph runs and node runs.
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')

T = TypeVar('T')
"""An invariant typevar."""
29@dataclass(init=False)
30class Graph(Generic[StateT, DepsT, RunEndT]):
31 """Definition of a graph.
33 In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define
34 their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph.
36 Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never
37 42 at the end.
39 ```py {title="never_42.py" noqa="I001" py="3.10"}
40 from __future__ import annotations
42 from dataclasses import dataclass
44 from pydantic_graph import BaseNode, End, Graph, GraphRunContext
46 @dataclass
47 class MyState:
48 number: int
50 @dataclass
51 class Increment(BaseNode[MyState]):
52 async def run(self, ctx: GraphRunContext) -> Check42:
53 ctx.state.number += 1
54 return Check42()
56 @dataclass
57 class Check42(BaseNode[MyState, None, int]):
58 async def run(self, ctx: GraphRunContext) -> Increment | End[int]:
59 if ctx.state.number == 42:
60 return Increment()
61 else:
62 return End(ctx.state.number)
64 never_42_graph = Graph(nodes=(Increment, Check42))
65 ```
66 _(This example is complete, it can be run "as is")_
    See [`run`][pydantic_graph.graph.Graph.run] for an example of running a graph, and
69 [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] for an example of generating a mermaid diagram
70 from the graph.
71 """
73 name: str | None
74 node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]]
75 snapshot_state: Callable[[StateT], StateT]
76 _state_type: type[StateT] | _utils.Unset = field(repr=False)
77 _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False)
def __init__(
    self,
    *,
    nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]],
    name: str | None = None,
    state_type: type[StateT] | _utils.Unset = _utils.UNSET,
    run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET,
    snapshot_state: Callable[[StateT], StateT] = deep_copy_state,
):
    """Build a graph out of a sequence of node classes.

    Every node is registered under its ID and, once all nodes are known, the edges they declare
    are checked so that each referenced node is actually part of the graph.

    Args:
        nodes: The node classes making up the graph; node IDs must be unique and all nodes should be
            generic in the same state type.
        name: Optional graph name. When omitted, the graph methods try to infer it from the calling
            frame the first time one of them is called (unless `infer_name=False` is passed).
        state_type: The state type of the graph; when left unset it is inferred from `nodes` on demand.
        run_end_type: The type of the value a run ends with; when left unset it is inferred from
            `nodes` on demand.
        snapshot_state: Function used to snapshot the state; it is handed to each
            [`NodeStep`][pydantic_graph.state.NodeStep] recorded in the run history.
    """
    self.name, self.snapshot_state = name, snapshot_state
    self._state_type, self._run_end_type = state_type, run_end_type

    # The frame must be taken here, inside `__init__`, so the helper sees the caller's namespace.
    caller_namespace = _utils.get_parent_namespace(inspect.currentframe())

    self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {}
    for node_cls in nodes:
        self._register_node(node_cls, caller_namespace)

    # Only possible once every node is registered: edges may point at nodes listed later.
    self._validate_edges()
async def run(
    self: Graph[StateT, DepsT, T],
    start_node: BaseNode[StateT, DepsT, T],
    *,
    state: StateT = None,
    deps: DepsT = None,
    infer_name: bool = True,
) -> tuple[T, list[HistoryStep[StateT, T]]]:
    """Execute the graph, beginning at `start_node`, until a node returns `End`.

    Args:
        start_node: The node to run first. The graph definition has no notion of an entry point,
            so the caller has to supply it.
        state: The initial state of the graph.
        deps: The dependencies of the graph.
        infer_name: Whether to infer the graph name from the calling frame.

    Returns:
        A tuple of the data carried by the final `End` and the history of the run.

    Raises:
        GraphRuntimeError: If a node returns something that is neither a `BaseNode` nor an `End`.

    Here's an example of running the graph from [above][pydantic_graph.graph.Graph]:

    ```py {title="run_never_42.py" noqa="I001" py="3.10"}
    from never_42 import Increment, MyState, never_42_graph

    async def main():
        state = MyState(1)
        _, history = await never_42_graph.run(Increment(), state=state)
        print(state)
        #> MyState(number=2)
        print(len(history))
        #> 3

        state = MyState(41)
        _, history = await never_42_graph.run(Increment(), state=state)
        print(state)
        #> MyState(number=43)
        print(len(history))
        #> 5
    ```
    """
    if infer_name and self.name is None:
        self._infer_name(inspect.currentframe())

    steps: list[HistoryStep[StateT, T]] = []
    span_name = self.name or 'graph'
    with _logfire.span('{graph_name} run {start=}', graph_name=span_name, start=start_node) as run_span:
        current = start_node
        while True:
            # `next` appends the step for `current` to `steps` itself.
            outcome = await self.next(current, steps, state=state, deps=deps, infer_name=False)

            if isinstance(outcome, End):
                steps.append(EndStep(result=outcome))
                run_span.set_attribute('history', steps)
                return outcome.data, steps

            if isinstance(outcome, BaseNode):
                current = outcome
                continue

            # Statically unreachable; at runtime a node may still return an unexpected object.
            if TYPE_CHECKING:
                typing_extensions.assert_never(outcome)
            else:
                raise exceptions.GraphRuntimeError(
                    f'Invalid node return type: `{type(outcome).__name__}`. Expected `BaseNode` or `End`.'
                )
def run_sync(
    self: Graph[StateT, DepsT, T],
    start_node: BaseNode[StateT, DepsT, T],
    *,
    state: StateT = None,
    deps: DepsT = None,
    infer_name: bool = True,
) -> tuple[T, list[HistoryStep[StateT, T]]]:
    """Run the graph synchronously.

    This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
    You therefore can't use this method inside async code or if there's an active event loop.

    If the current thread has no event loop yet, a new one is created and installed as the
    thread's current loop, so the method also works from worker threads.

    Args:
        start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
            you need to provide the starting node.
        state: The initial state of the graph.
        deps: The dependencies of the graph.
        infer_name: Whether to infer the graph name from the calling frame.

    Returns:
        The result type from ending the run and the history of the run.
    """
    if infer_name and self.name is None:
        # Infer here (not inside `run`) so the name comes from the frame that called `run_sync`.
        self._infer_name(inspect.currentframe())

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current loop: always the case in a fresh non-main thread, and on newer Python
        # versions `get_event_loop()` no longer creates one implicitly in the main thread either.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    return loop.run_until_complete(self.run(start_node, state=state, deps=deps, infer_name=False))
async def next(
    self: Graph[StateT, DepsT, T],
    node: BaseNode[StateT, DepsT, T],
    history: list[HistoryStep[StateT, T]],
    *,
    state: StateT = None,
    deps: DepsT = None,
    infer_name: bool = True,
) -> BaseNode[StateT, DepsT, Any] | End[T]:
    """Run a single node of the graph and return whatever it says should happen next.

    Args:
        node: The node to run.
        history: The history of the graph run so far. NOTE: this list is mutated, the step for
            `node` is appended to it.
        state: The current state of the graph.
        deps: The dependencies of the graph.
        infer_name: Whether to infer the graph name from the calling frame.

    Returns:
        The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished.

    Raises:
        GraphRuntimeError: If the ID of `node` is not registered in this graph.
    """
    if infer_name and self.name is None:
        self._infer_name(inspect.currentframe())

    ident = node.get_id()
    if ident not in self.node_defs:
        raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.')

    run_ctx = GraphRunContext(state, deps)
    with _logfire.span('run node {node_id}', node_id=ident, node=node):
        started_at = _utils.now_utc()
        t0 = perf_counter()
        following = await node.run(run_ctx)
        elapsed = perf_counter() - t0

    # The step is recorded after the node has run, outside the span.
    step = NodeStep(
        state=state,
        node=node,
        start_ts=started_at,
        duration=elapsed,
        snapshot_state=self.snapshot_state,
    )
    history.append(step)
    return following
def dump_history(
    self: Graph[StateT, DepsT, T], history: list[HistoryStep[StateT, T]], *, indent: int | None = None
) -> bytes:
    """Serialize the history of a graph run to JSON.

    Args:
        history: The history of the graph run.
        indent: The number of spaces to indent the JSON.

    Returns:
        The JSON representation of the history.
    """
    adapter = self.history_type_adapter
    return adapter.dump_json(history, indent=indent)
def load_history(self, json_bytes: str | bytes | bytearray) -> list[HistoryStep[StateT, RunEndT]]:
    """Parse and validate the history of a graph run from JSON.

    Args:
        json_bytes: The JSON representation of the history.

    Returns:
        The history of the graph run.
    """
    adapter = self.history_type_adapter
    return adapter.validate_json(json_bytes)
@cached_property
def history_type_adapter(self) -> pydantic.TypeAdapter[list[HistoryStep[StateT, RunEndT]]]:
    """Type adapter for this graph's run history, built once and cached on the instance.

    The adapter handles a list of history steps discriminated on their `kind` field.
    """
    node_classes = [definition.node for definition in self.node_defs.values()]
    state_type = self._get_state_type()
    run_end_type = self._get_run_end_type()

    # The node classes are exposed through the context variable only while the adapter is built;
    # the `finally` clause restores the previous value even if building fails.
    reset_token = nodes_schema_var.set(node_classes)
    try:
        return pydantic.TypeAdapter(
            list[Annotated[HistoryStep[state_type, run_end_type], pydantic.Discriminator('kind')]]
        )
    finally:
        nodes_schema_var.reset(reset_token)
def mermaid_code(
    self,
    *,
    start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
    title: str | None | typing_extensions.Literal[False] = None,
    edge_labels: bool = True,
    notes: bool = True,
    highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
    highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS,
    infer_name: bool = True,
    direction: mermaid.StateDiagramDirection | None = None,
) -> str:
    """Produce [mermaid](https://mermaid.js.org/) code for a diagram of the graph.

    This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code].

    Args:
        start_node: The node or nodes which can start the graph.
        title: The title of the diagram, use `False` to not include a title. When `None`, the graph
            name is used as the title if the graph has one.
        edge_labels: Whether to include edge labels.
        notes: Whether to include notes on each node.
        highlighted_nodes: Optional node or nodes to highlight.
        highlight_css: The CSS to use for highlighting nodes.
        infer_name: Whether to infer the graph name from the calling frame.
        direction: The direction of flow.

    Returns:
        The mermaid code for the graph, which can then be rendered as a diagram.

    Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]:

    ```py {title="never_42.py" py="3.10"}
    from never_42 import Increment, never_42_graph

    print(never_42_graph.mermaid_code(start_node=Increment))
    '''
    ---
    title: never_42_graph
    ---
    stateDiagram-v2
      [*] --> Increment
      Increment --> Check42
      Check42 --> Increment
      Check42 --> [*]
    '''
    ```

    The rendered diagram will look like this:

    ```mermaid
    ---
    title: never_42_graph
    ---
    stateDiagram-v2
      [*] --> Increment
      Increment --> Check42
      Check42 --> Increment
      Check42 --> [*]
    ```
    """
    if infer_name and self.name is None:
        self._infer_name(inspect.currentframe())

    # Fall back to the graph name only when no title was given at all; `False` and `''`
    # are kept here and turned into `None` (no title) by the `or None` below.
    diagram_title = self.name if (title is None and self.name) else title

    return mermaid.generate_code(
        self,
        start_node=start_node,
        highlighted_nodes=highlighted_nodes,
        highlight_css=highlight_css,
        title=diagram_title or None,
        edge_labels=edge_labels,
        notes=notes,
        direction=direction,
    )
def mermaid_image(
    self, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
) -> bytes:
    """Render a diagram of the graph and return it as image bytes.

    The format and diagram can be customized using `kwargs`,
    see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

    !!! note "Uses external service"
        This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
        is a free service not affiliated with Pydantic.

    Args:
        infer_name: Whether to infer the graph name from the calling frame.
        **kwargs: Additional arguments to pass to `mermaid.request_image`.

    Returns:
        The image bytes.
    """
    if infer_name and self.name is None:
        self._infer_name(inspect.currentframe())

    # An explicitly passed title always wins; otherwise default to the graph name, if any.
    if self.name:
        kwargs.setdefault('title', self.name)

    return mermaid.request_image(self, **kwargs)
def mermaid_save(
    self, path: Path | str, /, *, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
) -> None:
    """Render a diagram of the graph and write it to `path` as an image.

    The format and diagram can be customized using `kwargs`,
    see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

    !!! note "Uses external service"
        This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
        is a free service not affiliated with Pydantic.

    Args:
        path: The path to save the image to.
        infer_name: Whether to infer the graph name from the calling frame.
        **kwargs: Additional arguments to pass to `mermaid.save_image`.
    """
    if infer_name and self.name is None:
        self._infer_name(inspect.currentframe())

    # An explicitly passed title always wins; otherwise default to the graph name, if any.
    if self.name:
        kwargs.setdefault('title', self.name)

    mermaid.save_image(path, self, **kwargs)
def _get_state_type(self) -> type[StateT]:
    """Return the state type of the graph, inferring it from the nodes when it was not set explicitly.

    For each node, only the first original base that is a parametrized `BaseNode` is inspected;
    its first type argument is taken as the state type.
    """
    if _utils.is_set(self._state_type):
        return self._state_type

    for definition in self.node_defs.values():
        original_bases = typing_extensions.get_original_bases(definition.node)
        base_node_alias = next(
            (candidate for candidate in original_bases if typing_extensions.get_origin(candidate) is BaseNode),
            None,
        )
        if base_node_alias is None:
            continue
        type_args = typing_extensions.get_args(base_node_alias)
        if type_args:
            return type_args[0]

    # state defaults to None, so use that if we can't infer it
    return type(None)  # pyright: ignore[reportReturnType]
def _get_run_end_type(self) -> type[RunEndT]:
    """Return the run end type of the graph, inferring it from the nodes when it was not set explicitly.

    For each node, only the first original base that is a parametrized `BaseNode` is inspected;
    its third type argument is used if there are exactly three and `_utils.is_never` rejects it.

    Raises:
        GraphSetupError: If no node provides a usable run end type.
    """
    if _utils.is_set(self._run_end_type):
        return self._run_end_type

    for definition in self.node_defs.values():
        original_bases = typing_extensions.get_original_bases(definition.node)
        base_node_alias = next(
            (candidate for candidate in original_bases if typing_extensions.get_origin(candidate) is BaseNode),
            None,
        )
        if base_node_alias is None:
            continue
        type_args = typing_extensions.get_args(base_node_alias)
        if len(type_args) != 3:
            continue
        end_type = type_args[2]
        if not _utils.is_never(end_type):
            return end_type

    raise exceptions.GraphSetupError('Could not infer run end type from nodes, please set `run_end_type`.')
438 def _register_node(
439 self: Graph[StateT, DepsT, T],
440 node: type[BaseNode[StateT, DepsT, T]],
441 parent_namespace: dict[str, Any] | None,
442 ) -> None:
443 node_id = node.get_id()
444 if existing_node := self.node_defs.get(node_id):
445 raise exceptions.GraphSetupError(
446 f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}'
447 )
448 else:
449 self.node_defs[node_id] = node.get_node_def(parent_namespace)
451 def _validate_edges(self):
452 known_node_ids = self.node_defs.keys()
453 bad_edges: dict[str, list[str]] = {}
455 for node_id, node_def in self.node_defs.items():
456 for edge in node_def.next_node_edges.keys():
457 if edge not in known_node_ids:
458 bad_edges.setdefault(edge, []).append(f'`{node_id}`')
460 if bad_edges:
461 bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()]
462 if len(bad_edges_list) == 1:
463 raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.')
464 else:
465 b = '\n'.join(f' {be}' for be in bad_edges_list)
466 raise exceptions.GraphSetupError(
467 f'Nodes are referenced in the graph but not included in the graph:\n{b}'
468 )
470 def _infer_name(self, function_frame: types.FrameType | None) -> None:
471 """Infer the agent name from the call frame.
473 Usage should be `self._infer_name(inspect.currentframe())`.
475 Copied from `Agent`.
476 """
477 assert self.name is None, 'Name already set'
478 if function_frame is not None and (parent_frame := function_frame.f_back): # pragma: no branch
479 for name, item in parent_frame.f_locals.items():
480 if item is self:
481 self.name = name
482 return
483 if parent_frame.f_locals != parent_frame.f_globals: 483 ↛ exitline 483 didn't return from function '_infer_name' because the condition on line 483 was always true
484 # if we couldn't find the agent in locals and globals are a different dict, try globals
485 for name, item in parent_frame.f_globals.items(): 485 ↛ exitline 485 didn't return from function '_infer_name' because the loop on line 485 didn't complete
486 if item is self:
487 self.name = name
488 return