Coverage for pydantic_graph/pydantic_graph/graph.py: 97.55%
223 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3import inspect
4import types
5from collections.abc import AsyncIterator, Sequence
6from contextlib import AbstractContextManager, ExitStack, asynccontextmanager
7from dataclasses import dataclass, field
8from functools import cached_property
9from typing import Any, Generic, cast
11import logfire_api
12import typing_extensions
13from logfire_api import LogfireSpan
14from typing_extensions import deprecated
15from typing_inspection import typing_objects
17from . import _utils, exceptions, mermaid
18from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT
19from .persistence import BaseStatePersistence
20from .persistence.in_mem import SimpleStatePersistence
# `Path` is imported unconditionally (not only when `logfire` is installed) because it is also
# referenced by annotations in this module, e.g. `Graph.mermaid_save(path: Path | str, ...)`.
# Those annotations must stay resolvable by `typing.get_type_hints` and type checkers even
# when the optional `logfire` package is missing.
from pathlib import Path

# while waiting for https://github.com/pydantic/logfire/issues/745
try:
    import logfire._internal.stack_info
except ImportError:
    # `logfire` is optional; without it there is nothing to patch.
    pass
else:
    # Mark this package's directory as non-user code for logfire's stack info.
    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
# Names exported by `from ... import *`.
__all__ = 'Graph', 'GraphRun', 'GraphRunResult'

# Logfire instance created with the 'pydantic-graph' otel scope; `GraphRun.next` uses it to open
# the per-node 'run node {node_id}' spans.
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')
@dataclass(init=False)
class Graph(Generic[StateT, DepsT, RunEndT]):
    """Definition of a graph.

    In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define
    their outgoing edges — i.e. which nodes may be run next, and thereby the structure of the graph.

    Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never
    42 at the end.

    ```py {title="never_42.py" noqa="I001" py="3.10"}
    from __future__ import annotations

    from dataclasses import dataclass

    from pydantic_graph import BaseNode, End, Graph, GraphRunContext

    @dataclass
    class MyState:
        number: int

    @dataclass
    class Increment(BaseNode[MyState]):
        async def run(self, ctx: GraphRunContext) -> Check42:
            ctx.state.number += 1
            return Check42()

    @dataclass
    class Check42(BaseNode[MyState, None, int]):
        async def run(self, ctx: GraphRunContext) -> Increment | End[int]:
            if ctx.state.number == 42:
                return Increment()
            else:
                return End(ctx.state.number)

    never_42_graph = Graph(nodes=(Increment, Check42))
    ```
    _(This example is complete, it can be run "as is")_

    See [`run`][pydantic_graph.graph.Graph.run] for an example of running the graph, and
    [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] for an example of generating a mermaid diagram
    from the graph.
    """

    name: str | None
    node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]]
    _state_type: type[StateT] | _utils.Unset = field(repr=False)
    _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False)
    auto_instrument: bool = field(repr=False)

    def __init__(
        self,
        *,
        nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]],
        name: str | None = None,
        state_type: type[StateT] | _utils.Unset = _utils.UNSET,
        run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET,
        auto_instrument: bool = True,
    ):
        """Create a graph from a sequence of nodes.

        Args:
            nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same
                state type.
            name: Optional name for the graph, if not provided the name will be inferred from the calling frame
                on the first call to a graph method.
            state_type: The type of the state for the graph, this can generally be inferred from `nodes`.
            run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`.
            auto_instrument: Whether to create a span for the graph run and the execution of each node's run method.

        Raises:
            GraphSetupError: If two nodes share a node ID, or a node references a node that is not in `nodes`.
        """
        self.name = name
        self._state_type = state_type
        self._run_end_type = run_end_type
        self.auto_instrument = auto_instrument

        # The caller's namespace is handed to each node's `get_node_def`.
        parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
        self.node_defs = {}
        for node in nodes:
            self._register_node(node, parent_namespace)

        self._validate_edges()

    async def run(
        self,
        start_node: BaseNode[StateT, DepsT, RunEndT],
        *,
        state: StateT = None,
        deps: DepsT = None,
        persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
        infer_name: bool = True,
        span: LogfireSpan | None = None,
    ) -> GraphRunResult[StateT, RunEndT]:
        """Run the graph from a starting node until it ends.

        Args:
            start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            persistence: State persistence interface, defaults to
                [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`.
            infer_name: Whether to infer the graph name from the calling frame.
            span: The span to use for the graph run. If not provided, a span will be created depending on the value of
                the `auto_instrument` field.

        Returns:
            A `GraphRunResult` containing information about the run, including its final result.

        Here's an example of running the graph from [above][pydantic_graph.graph.Graph]:

        ```py {title="run_never_42.py" noqa="I001" py="3.10"}
        from never_42 import Increment, MyState, never_42_graph

        async def main():
            state = MyState(1)
            await never_42_graph.run(Increment(), state=state)
            print(state)
            #> MyState(number=2)

            state = MyState(41)
            await never_42_graph.run(Increment(), state=state)
            print(state)
            #> MyState(number=43)
        ```
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        # The name (if any) was inferred above from *our* caller, so `iter` must not try again.
        async with self.iter(
            start_node, state=state, deps=deps, persistence=persistence, span=span, infer_name=False
        ) as graph_run:
            async for _node in graph_run:
                pass

        final_result = graph_run.result
        assert final_result is not None, 'GraphRun should have a final result'
        return final_result

    def run_sync(
        self,
        start_node: BaseNode[StateT, DepsT, RunEndT],
        *,
        state: StateT = None,
        deps: DepsT = None,
        persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
        infer_name: bool = True,
    ) -> GraphRunResult[StateT, RunEndT]:
        """Synchronously run the graph.

        This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
        You therefore can't use this method inside async code or if there's an active event loop.

        Args:
            start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            persistence: State persistence interface, defaults to
                [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The result type from ending the run and the history of the run.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        return _utils.get_event_loop().run_until_complete(
            self.run(start_node, state=state, deps=deps, persistence=persistence, infer_name=False)
        )

    @asynccontextmanager
    async def iter(
        self,
        start_node: BaseNode[StateT, DepsT, RunEndT],
        *,
        state: StateT = None,
        deps: DepsT = None,
        persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
        span: AbstractContextManager[Any] | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
        """A contextmanager which can be used to iterate over the graph's nodes as they are executed.

        This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as
        they are executed. This is the API to use if you want to record or interact with the nodes as the graph
        execution unfolds.

        The `GraphRun` can also be used to manually drive the graph execution by calling
        [`GraphRun.next`][pydantic_graph.graph.GraphRun.next].

        The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once
        it has completed.

        For more details, see the API documentation of [`GraphRun`][pydantic_graph.graph.GraphRun].

        Args:
            start_node: the first node to run. Since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            persistence: State persistence interface, defaults to
                [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`.
            span: The span to use for the graph run. If not provided, a new span will be created.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns: A GraphRun that can be async iterated over to drive the graph to completion.
        """
        if infer_name and self.name is None:
            # f_back because `asynccontextmanager` adds one frame
            if frame := inspect.currentframe():  # pragma: no branch
                self._infer_name(frame.f_back)

        if persistence is None:
            persistence = SimpleStatePersistence()
        persistence.set_graph_types(self)

        if self.auto_instrument and span is None:
            span = logfire_api.span('run graph {graph.name}', graph=self)

        # `ExitStack` lets the span be entered only when there is one.
        with ExitStack() as stack:
            if span is not None:
                stack.enter_context(span)
            yield GraphRun[StateT, DepsT, RunEndT](
                graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps
            )

    @asynccontextmanager
    async def iter_from_persistence(
        self,
        persistence: BaseStatePersistence[StateT, RunEndT],
        *,
        deps: DepsT = None,
        span: AbstractContextManager[Any] | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
        """A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object.

        This method has similar functionality to [`iter`][pydantic_graph.graph.Graph.iter],
        but instead of passing the node to run, it will restore the node and state from state persistence.

        Args:
            persistence: The state persistence interface to use.
            deps: The dependencies of the graph.
            span: The span to use for the graph run. If not provided, a new span will be created.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns: A GraphRun that can be async iterated over to drive the graph to completion.

        Raises:
            GraphRuntimeError: If `persistence.load_next()` returns no snapshot to restore.
        """
        if infer_name and self.name is None:
            # f_back because `asynccontextmanager` adds one frame
            if frame := inspect.currentframe():  # pragma: no branch
                self._infer_name(frame.f_back)

        persistence.set_graph_types(self)

        snapshot = await persistence.load_next()
        if snapshot is None:
            raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.')

        # Tag the restored node with its snapshot ID; the same ID is passed to `GraphRun` below.
        snapshot.node.set_snapshot_id(snapshot.id)

        if self.auto_instrument and span is None:
            span = logfire_api.span('run graph {graph.name}', graph=self)

        with ExitStack() as stack:
            if span is not None:
                stack.enter_context(span)
            yield GraphRun[StateT, DepsT, RunEndT](
                graph=self,
                start_node=snapshot.node,
                persistence=persistence,
                state=snapshot.state,
                deps=deps,
                snapshot_id=snapshot.id,
            )

    async def initialize(
        self,
        node: BaseNode[StateT, DepsT, RunEndT],
        persistence: BaseStatePersistence[StateT, RunEndT],
        *,
        state: StateT = None,
        infer_name: bool = True,
    ) -> None:
        """Initialize a new graph run in persistence without running it.

        This is useful if you want to set up a graph run to be run later, e.g. via
        [`iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence].

        Args:
            node: The node to run first.
            persistence: State persistence interface.
            state: The start state of the graph.
            infer_name: Whether to infer the graph name from the calling frame.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        persistence.set_graph_types(self)
        await persistence.snapshot_node(state, node)

    @deprecated('`next` is deprecated, use `async with graph.iter(...) as run:  run.next()` instead')
    async def next(
        self,
        node: BaseNode[StateT, DepsT, RunEndT],
        persistence: BaseStatePersistence[StateT, RunEndT],
        *,
        state: StateT = None,
        deps: DepsT = None,
        infer_name: bool = True,
    ) -> BaseNode[StateT, DepsT, Any] | End[RunEndT]:
        """Run a node in the graph and return the next node to run.

        Args:
            node: The node to run.
            persistence: State persistence interface.
            state: The current state of the graph.
            deps: The dependencies of the graph.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        persistence.set_graph_types(self)
        run = GraphRun[StateT, DepsT, RunEndT](
            graph=self,
            start_node=node,
            persistence=persistence,
            state=state,
            deps=deps,
        )
        return await run.next(node)

    def mermaid_code(
        self,
        *,
        start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        title: str | None | typing_extensions.Literal[False] = None,
        edge_labels: bool = True,
        notes: bool = True,
        highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS,
        infer_name: bool = True,
        direction: mermaid.StateDiagramDirection | None = None,
    ) -> str:
        """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram.

        This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code].

        Args:
            start_node: The node or nodes which can start the graph.
            title: The title of the diagram, use `False` to not include a title.
            edge_labels: Whether to include edge labels.
            notes: Whether to include notes on each node.
            highlighted_nodes: Optional node or nodes to highlight.
            highlight_css: The CSS to use for highlighting nodes.
            infer_name: Whether to infer the graph name from the calling frame.
            direction: The direction of flow.

        Returns:
            The mermaid code for the graph, which can then be rendered as a diagram.

        Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]:

        ```py {title="mermaid_never_42.py" py="3.10"}
        from never_42 import Increment, never_42_graph

        print(never_42_graph.mermaid_code(start_node=Increment))
        '''
        ---
        title: never_42_graph
        ---
        stateDiagram-v2
          [*] --> Increment
          Increment --> Check42
          Check42 --> Increment
          Check42 --> [*]
        '''
        ```

        The rendered diagram will look like this:

        ```mermaid
        ---
        title: never_42_graph
        ---
        stateDiagram-v2
          [*] --> Increment
          Increment --> Check42
          Check42 --> Increment
          Check42 --> [*]
        ```
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        if title is None and self.name:
            title = self.name
        return mermaid.generate_code(
            self,
            start_node=start_node,
            highlighted_nodes=highlighted_nodes,
            highlight_css=highlight_css,
            # `False` (explicitly no title) is normalised to `None` for `generate_code`.
            title=title or None,
            edge_labels=edge_labels,
            notes=notes,
            direction=direction,
        )

    def mermaid_image(
        self, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
    ) -> bytes:
        """Generate a diagram representing the graph as an image.

        The format and diagram can be customized using `kwargs`,
        see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

        !!! note "Uses external service"
            This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
            is a free service not affiliated with Pydantic.

        Args:
            infer_name: Whether to infer the graph name from the calling frame.
            **kwargs: Additional arguments to pass to `mermaid.request_image`.

        Returns:
            The image bytes.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        # Default the diagram title to the graph name unless the caller supplied one.
        if 'title' not in kwargs and self.name:
            kwargs['title'] = self.name
        return mermaid.request_image(self, **kwargs)

    def mermaid_save(
        self, path: Path | str, /, *, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
    ) -> None:
        """Generate a diagram representing the graph and save it as an image.

        The format and diagram can be customized using `kwargs`,
        see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

        !!! note "Uses external service"
            This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
            is a free service not affiliated with Pydantic.

        Args:
            path: The path to save the image to.
            infer_name: Whether to infer the graph name from the calling frame.
            **kwargs: Additional arguments to pass to `mermaid.save_image`.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        # Default the diagram title to the graph name unless the caller supplied one.
        if 'title' not in kwargs and self.name:
            kwargs['title'] = self.name
        mermaid.save_image(path, self, **kwargs)

    def get_nodes(self) -> Sequence[type[BaseNode[StateT, DepsT, RunEndT]]]:
        """Get the nodes in the graph."""
        return [node_def.node for node_def in self.node_defs.values()]

    @cached_property
    def inferred_types(self) -> tuple[type[StateT], type[RunEndT]]:
        """Get the types of the state and run end from the graph.

        Explicitly provided `state_type` / `run_end_type` take precedence; whatever is missing is read from the
        generic arguments of each node's `BaseNode[...]` base. Anything that still can't be inferred is `None`.
        """
        if _utils.is_set(self._state_type) and _utils.is_set(self._run_end_type):
            return self._state_type, self._run_end_type

        state_type = self._state_type
        run_end_type = self._run_end_type

        for node_def in self.node_defs.values():
            for base in typing_extensions.get_original_bases(node_def.node):
                if typing_extensions.get_origin(base) is BaseNode:
                    args = typing_extensions.get_args(base)
                    if not _utils.is_set(state_type) and args:
                        state_type = args[0]

                    if not _utils.is_set(run_end_type) and len(args) == 3:
                        t = args[2]
                        # `Never` as the third argument is not usable as a run end type, keep looking
                        if not typing_objects.is_never(t):
                            run_end_type = t
                    if _utils.is_set(state_type) and _utils.is_set(run_end_type):
                        return state_type, run_end_type  # pyright: ignore[reportReturnType]
                    # break the inner (bases) loop
                    break

        if not _utils.is_set(state_type):
            # state defaults to None, so use that if we can't infer it
            state_type = None
        if not _utils.is_set(run_end_type):
            # this happens if a graph has no return nodes, use None so any downstream errors are clear
            run_end_type = None
        return state_type, run_end_type  # pyright: ignore[reportReturnType]

    def _register_node(
        self,
        node: type[BaseNode[StateT, DepsT, RunEndT]],
        parent_namespace: dict[str, Any] | None,
    ) -> None:
        """Add `node` to `self.node_defs`, keyed by its node ID.

        Raises:
            GraphSetupError: If another node with the same ID has already been registered.
        """
        node_id = node.get_node_id()
        if existing_node := self.node_defs.get(node_id):
            raise exceptions.GraphSetupError(
                f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}'
            )
        self.node_defs[node_id] = node.get_node_def(parent_namespace)

    def _validate_edges(self):
        """Check that every node referenced by an outgoing edge is itself part of the graph.

        Raises:
            GraphSetupError: If one or more edges point at node IDs missing from `self.node_defs`.
        """
        # Maps each missing node ID to the (backtick-quoted) IDs of the nodes that reference it.
        bad_edges: dict[str, list[str]] = {}

        for node_id, node_def in self.node_defs.items():
            for edge in node_def.next_node_edges:
                if edge not in self.node_defs:
                    bad_edges.setdefault(edge, []).append(f'`{node_id}`')

        if bad_edges:
            bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()]
            if len(bad_edges_list) == 1:
                raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.')
            else:
                b = '\n'.join(f' {be}' for be in bad_edges_list)
                raise exceptions.GraphSetupError(
                    f'Nodes are referenced in the graph but not included in the graph:\n{b}'
                )

    def _infer_name(self, function_frame: types.FrameType | None) -> None:
        """Infer the graph name from the call frame.

        Looks for a variable bound to this graph in the locals, then the globals, of the frame that called
        `function_frame`, and stores that variable's name in `self.name`. Leaves `self.name` as `None` if
        nothing is found.

        Usage should be `self._infer_name(inspect.currentframe())`.

        Copied from `Agent`.
        """
        assert self.name is None, 'Name already set'
        if function_frame is not None and (parent_frame := function_frame.f_back):  # pragma: no branch
            for name, item in parent_frame.f_locals.items():
                if item is self:
                    self.name = name
                    return
            # Identity, not equality: we only care whether globals are a *different* namespace, and `!=`
            # would deep-compare the dicts, invoking arbitrary `__eq__` implementations of user objects.
            if parent_frame.f_locals is not parent_frame.f_globals:
                # if we couldn't find the graph in locals and globals are a different dict, try globals
                for name, item in parent_frame.f_globals.items():
                    if item is self:
                        self.name = name
                        return
class GraphRun(Generic[StateT, DepsT, RunEndT]):
    """A stateful, async-iterable run of a [`Graph`][pydantic_graph.graph.Graph].

    You typically get a `GraphRun` instance from calling
    `async with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate
    through nodes as they run, either by `async for` iteration or by repeatedly calling `.next(...)`.

    Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]:
    ```py {title="iter_never_42.py" noqa="I001" py="3.10"}
    from copy import deepcopy
    from never_42 import Increment, MyState, never_42_graph

    async def main():
        state = MyState(1)
        async with never_42_graph.iter(Increment(), state=state) as graph_run:
            node_states = [(graph_run.next_node, deepcopy(graph_run.state))]
            async for node in graph_run:
                node_states.append((node, deepcopy(graph_run.state)))
            print(node_states)
            '''
            [
                (Increment(), MyState(number=1)),
                (Check42(), MyState(number=2)),
                (End(data=2), MyState(number=2)),
            ]
            '''

        state = MyState(41)
        async with never_42_graph.iter(Increment(), state=state) as graph_run:
            node_states = [(graph_run.next_node, deepcopy(graph_run.state))]
            async for node in graph_run:
                node_states.append((node, deepcopy(graph_run.state)))
            print(node_states)
            '''
            [
                (Increment(), MyState(number=41)),
                (Check42(), MyState(number=42)),
                (Increment(), MyState(number=42)),
                (Check42(), MyState(number=43)),
                (End(data=43), MyState(number=43)),
            ]
            '''
    ```

    See the [`GraphRun.next` documentation][pydantic_graph.graph.GraphRun.next] for an example of how to manually
    drive the graph run.
    """

    def __init__(
        self,
        *,
        graph: Graph[StateT, DepsT, RunEndT],
        start_node: BaseNode[StateT, DepsT, RunEndT],
        persistence: BaseStatePersistence[StateT, RunEndT],
        state: StateT,
        deps: DepsT,
        snapshot_id: str | None = None,
    ):
        """Create a new run for a given graph, starting at the specified node.

        Typically, you'll use [`Graph.iter`][pydantic_graph.graph.Graph.iter] rather than calling this directly.

        Args:
            graph: The [`Graph`][pydantic_graph.graph.Graph] to run.
            start_node: The node where execution will begin.
            persistence: State persistence interface.
            state: A shared state object or primitive (like a counter, dataclass, etc.) that is available
                to all nodes via `ctx.state`.
            deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections,
                configuration, or logging clients.
            snapshot_id: The ID of the snapshot the node came from.
        """
        self.graph = graph
        self.persistence = persistence
        self._snapshot_id: str | None = snapshot_id
        self.state = state
        self.deps = deps

        self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node

    @property
    def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
        """The next node that will be run in the graph.

        This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
        """
        return self._next_node

    @property
    def result(self) -> GraphRunResult[StateT, RunEndT] | None:
        """The final result of the graph run if the run is completed, otherwise `None`."""
        if not isinstance(self._next_node, End):
            return None  # The GraphRun has not finished running
        return GraphRunResult(
            self._next_node.data,
            state=self.state,
            persistence=self.persistence,
        )

    async def next(
        self, node: BaseNode[StateT, DepsT, RunEndT] | None = None
    ) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
        """Manually drive the graph run by passing in the node you want to run next.

        This lets you inspect or mutate the node before continuing execution, or skip certain nodes
        under dynamic conditions. The graph run should stop when you return an [`End`][pydantic_graph.nodes.End] node.

        Here's an example of using `next` to drive the graph from [above][pydantic_graph.graph.Graph]:
        ```py {title="next_never_42.py" noqa="I001" py="3.10"}
        from copy import deepcopy
        from pydantic_graph import End
        from never_42 import Increment, MyState, never_42_graph

        async def main():
            state = MyState(48)
            async with never_42_graph.iter(Increment(), state=state) as graph_run:
                next_node = graph_run.next_node  # start with the first node
                node_states = [(next_node, deepcopy(graph_run.state))]

                while not isinstance(next_node, End):
                    if graph_run.state.number == 50:
                        graph_run.state.number = 42
                    next_node = await graph_run.next(next_node)
                    node_states.append((next_node, deepcopy(graph_run.state)))

                print(node_states)
                '''
                [
                    (Increment(), MyState(number=48)),
                    (Check42(), MyState(number=49)),
                    (End(data=49), MyState(number=49)),
                ]
                '''
        ```

        Args:
            node: The node to run next in the graph. If not specified, uses `self.next_node`, which is initialized to
                the `start_node` of the run and updated each time a new node is returned.

        Returns:
            The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if
            the run has completed.

        Raises:
            TypeError: If the node to run is not a `BaseNode` instance, e.g. `next()` is called again after the
                run already reached an `End`.
            GraphRuntimeError: If the node is not part of the graph, or if the node's `run` method returns
                something that is neither a `BaseNode` nor an `End`.
        """
        if node is None:
            # This cast is necessary because self._next_node could be an `End`; that only happens if `next` is
            # called manually after the run finished, and the `isinstance` check below rejects it.
            node = cast(BaseNode[StateT, DepsT, RunEndT], self._next_node)

        # Validate the node *before* touching it or the persistence layer, so that an invalid value (such as an
        # `End`) is never snapshotted and always produces the helpful error below.
        if not isinstance(node, BaseNode):
            # While technically this is not compatible with the documented method signature, it's an easy mistake to
            # make, and we should eagerly provide a more helpful error message than you'd get otherwise.
            raise TypeError(f'`next` must be called with a `BaseNode` instance, got {node!r}.')

        node_snapshot_id = node.get_snapshot_id()
        if node_snapshot_id != self._snapshot_id:
            # The caller supplied a node other than the one last snapshotted, record it if it's new.
            await self.persistence.snapshot_node_if_new(node_snapshot_id, self.state, node)
            self._snapshot_id = node_snapshot_id

        node_id = node.get_node_id()
        if node_id not in self.graph.node_defs:
            raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.')

        with ExitStack() as stack:
            if self.graph.auto_instrument:
                stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node))

            async with self.persistence.record_run(node_snapshot_id):
                ctx = GraphRunContext(self.state, self.deps)
                self._next_node = await node.run(ctx)

        if isinstance(self._next_node, End):
            self._snapshot_id = self._next_node.get_snapshot_id()
            await self.persistence.snapshot_end(self.state, self._next_node)
        elif isinstance(self._next_node, BaseNode):
            self._snapshot_id = self._next_node.get_snapshot_id()
            await self.persistence.snapshot_node(self.state, self._next_node)
        else:
            raise exceptions.GraphRuntimeError(
                f'Invalid node return type: `{type(self._next_node).__name__}`. Expected `BaseNode` or `End`.'
            )

        return self._next_node

    def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]]:
        """Return `self`; the run is its own async iterator."""
        return self

    async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
        """Use the last returned node as the input to `GraphRun.next`."""
        if isinstance(self._next_node, End):
            raise StopAsyncIteration
        return await self.next(self._next_node)

    def __repr__(self) -> str:
        return f'<GraphRun graph={self.graph.name or "[unnamed]"}>'
@dataclass
class GraphRunResult(Generic[StateT, RunEndT]):
    """The final result of running a graph."""

    # The data carried by the `End` node that finished the run.
    output: RunEndT
    # The graph state at the time the result was built.
    state: StateT
    # The persistence interface used by the run; left out of the dataclass `repr`.
    persistence: BaseStatePersistence[StateT, RunEndT] = field(repr=False)