Coverage for pydantic_graph/pydantic_graph/graph.py: 97.97%
170 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-30 19:21 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-30 19:21 +0000
1from __future__ import annotations as _annotations
3import asyncio
4import inspect
5import types
6from collections.abc import Sequence
7from contextlib import ExitStack
8from dataclasses import dataclass, field
9from functools import cached_property
10from pathlib import Path
11from time import perf_counter
12from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, TypeVar
14import logfire_api
15import pydantic
16import typing_extensions
18from . import _utils, exceptions, mermaid
19from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT
20from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var
# while waiting for https://github.com/pydantic/logfire/issues/745
try:
    import logfire._internal.stack_info
except ImportError:
    # logfire (the full SDK) is optional; without it there is nothing to register.
    pass
else:
    # Add this package's directory to logfire's non-user-code prefixes — presumably so logfire skips
    # pydantic-graph frames when attributing spans to source locations (see the issue linked above).
    # `Path` comes from the module-level `pathlib` import; no local re-import is needed.
    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

__all__ = ('Graph',)

# Instrumentation handle used for the graph-run and node-run spans.
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')

T = TypeVar('T')
"""An invariant typevar."""
@dataclass(init=False)
class Graph(Generic[StateT, DepsT, RunEndT]):
    """Definition of a graph.

    In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define
    their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph.

    Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never
    42 at the end.

    ```py {title="never_42.py" noqa="I001" py="3.10"}
    from __future__ import annotations

    from dataclasses import dataclass

    from pydantic_graph import BaseNode, End, Graph, GraphRunContext

    @dataclass
    class MyState:
        number: int

    @dataclass
    class Increment(BaseNode[MyState]):
        async def run(self, ctx: GraphRunContext) -> Check42:
            ctx.state.number += 1
            return Check42()

    @dataclass
    class Check42(BaseNode[MyState, None, int]):
        async def run(self, ctx: GraphRunContext) -> Increment | End[int]:
            if ctx.state.number == 42:
                return Increment()
            else:
                return End(ctx.state.number)

    never_42_graph = Graph(nodes=(Increment, Check42))
    ```
    _(This example is complete, it can be run "as is")_

    See [`run`][pydantic_graph.graph.Graph.run] for an example of running a graph, and
    [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] for an example of generating a mermaid diagram
    from the graph.
    """

    # Graph name; when `None` it is inferred from the calling frame on the first call to a graph method.
    name: str | None
    # Node definitions keyed by node ID (`BaseNode.get_id()`).
    node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]]
    # Function used to snapshot the state when `NodeStep`s are recorded.
    snapshot_state: Callable[[StateT], StateT]
    # Explicit state type, or `UNSET` to infer it from the nodes' `BaseNode[...]` bases.
    _state_type: type[StateT] | _utils.Unset = field(repr=False)
    # Explicit run end type, or `UNSET` to infer it from the nodes' `BaseNode[...]` bases.
    _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False)
    # Whether `run`/`next` open logfire spans.
    _auto_instrument: bool = field(repr=False)

    def __init__(
        self,
        *,
        nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]],
        name: str | None = None,
        state_type: type[StateT] | _utils.Unset = _utils.UNSET,
        run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET,
        snapshot_state: Callable[[StateT], StateT] = deep_copy_state,
        auto_instrument: bool = True,
    ):
        """Create a graph from a sequence of nodes.

        Args:
            nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same
                state type.
            name: Optional name for the graph, if not provided the name will be inferred from the calling frame
                on the first call to a graph method.
            state_type: The type of the state for the graph, this can generally be inferred from `nodes`.
            run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`.
            snapshot_state: A function to snapshot the state of the graph, this is used in
                [`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record
                the state before each step.
            auto_instrument: Whether to create a span for the graph run and the execution of each node's run method.

        Raises:
            GraphSetupError: If two nodes share an ID, or a node references a node not included in `nodes`.
        """
        self.name = name
        self._state_type = state_type
        self._run_end_type = run_end_type
        self._auto_instrument = auto_instrument
        self.snapshot_state = snapshot_state

        # The caller's namespace is passed on when building node definitions (e.g. for resolving
        # forward-referenced return annotations of node `run` methods).
        parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
        self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {}
        for node in nodes:
            self._register_node(node, parent_namespace)

        self._validate_edges()

    async def run(
        self: Graph[StateT, DepsT, T],
        start_node: BaseNode[StateT, DepsT, T],
        *,
        state: StateT = None,
        deps: DepsT = None,
        infer_name: bool = True,
    ) -> tuple[T, list[HistoryStep[StateT, T]]]:
        """Run the graph from a starting node until it ends.

        Args:
            start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The result type from ending the run and the history of the run.

        Raises:
            GraphRuntimeError: If a node is not part of this graph, or a node returns something that is neither
                a `BaseNode` nor an `End`.

        Here's an example of running the graph from [above][pydantic_graph.graph.Graph]:

        ```py {title="run_never_42.py" noqa="I001" py="3.10"}
        from never_42 import Increment, MyState, never_42_graph

        async def main():
            state = MyState(1)
            _, history = await never_42_graph.run(Increment(), state=state)
            print(state)
            #> MyState(number=2)
            print(len(history))
            #> 3

            state = MyState(41)
            _, history = await never_42_graph.run(Increment(), state=state)
            print(state)
            #> MyState(number=43)
            print(len(history))
            #> 5
        ```
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        history: list[HistoryStep[StateT, T]] = []
        with ExitStack() as stack:
            run_span: logfire_api.LogfireSpan | None = None
            if self._auto_instrument:
                run_span = stack.enter_context(
                    _logfire.span(
                        '{graph_name} run {start=}',
                        graph_name=self.name or 'graph',
                        start=start_node,
                    )
                )
            while True:
                # `next` appends a `NodeStep` to `history` for every node it runs.
                next_node = await self.next(start_node, history, state=state, deps=deps, infer_name=False)
                if isinstance(next_node, End):
                    history.append(EndStep(result=next_node))
                    if run_span is not None:
                        run_span.set_attribute('history', history)
                    return next_node.data, history
                elif isinstance(next_node, BaseNode):
                    start_node = next_node
                else:
                    if TYPE_CHECKING:
                        typing_extensions.assert_never(next_node)
                    else:
                        # Guard against nodes whose `run` returns something other than the annotated types.
                        raise exceptions.GraphRuntimeError(
                            f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
                        )

    def run_sync(
        self: Graph[StateT, DepsT, T],
        start_node: BaseNode[StateT, DepsT, T],
        *,
        state: StateT = None,
        deps: DepsT = None,
        infer_name: bool = True,
    ) -> tuple[T, list[HistoryStep[StateT, T]]]:
        """Run the graph synchronously.

        This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
        You therefore can't use this method inside async code or if there's an active event loop.

        If the current thread has no event loop (e.g. a non-main thread), a new one is created and set as the
        thread's current loop; it is left open so subsequent calls can reuse it.

        Args:
            start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The result type from ending the run and the history of the run.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        # `asyncio.get_event_loop()` raises `RuntimeError` when the current thread has no event loop
        # (always in non-main threads, and on the main thread in newer Python releases), so fall back
        # to creating and registering a fresh loop instead of failing.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(self.run(start_node, state=state, deps=deps, infer_name=False))

    async def next(
        self: Graph[StateT, DepsT, T],
        node: BaseNode[StateT, DepsT, T],
        history: list[HistoryStep[StateT, T]],
        *,
        state: StateT = None,
        deps: DepsT = None,
        infer_name: bool = True,
    ) -> BaseNode[StateT, DepsT, Any] | End[T]:
        """Run a node in the graph and return the next node to run.

        Args:
            node: The node to run.
            history: The history of the graph run so far. NOTE: this will be mutated to add the new step.
            state: The current state of the graph.
            deps: The dependencies of the graph.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished.

        Raises:
            GraphRuntimeError: If `node` is not part of this graph.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        node_id = node.get_id()
        if node_id not in self.node_defs:
            raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.')

        with ExitStack() as stack:
            if self._auto_instrument:
                stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node))
            ctx = GraphRunContext(state, deps)
            start_ts = _utils.now_utc()
            start = perf_counter()
            next_node = await node.run(ctx)
            duration = perf_counter() - start

        # The step is recorded only after the node's `run` has completed (it is skipped if `run` raises).
        history.append(
            NodeStep(state=state, node=node, start_ts=start_ts, duration=duration, snapshot_state=self.snapshot_state)
        )
        return next_node

    def dump_history(
        self: Graph[StateT, DepsT, T], history: list[HistoryStep[StateT, T]], *, indent: int | None = None
    ) -> bytes:
        """Dump the history of a graph run as JSON.

        Args:
            history: The history of the graph run.
            indent: The number of spaces to indent the JSON.

        Returns:
            The JSON representation of the history.
        """
        return self.history_type_adapter.dump_json(history, indent=indent)

    def load_history(self, json_bytes: str | bytes | bytearray) -> list[HistoryStep[StateT, RunEndT]]:
        """Load the history of a graph run from JSON.

        Args:
            json_bytes: The JSON representation of the history.

        Returns:
            The history of the graph run.
        """
        return self.history_type_adapter.validate_json(json_bytes)

    @cached_property
    def history_type_adapter(self) -> pydantic.TypeAdapter[list[HistoryStep[StateT, RunEndT]]]:
        """Pydantic `TypeAdapter` for a run's history, used by `dump_history` and `load_history`.

        Built lazily on first access and cached on the instance. History steps are discriminated on their
        `kind` field. Raises `GraphSetupError` if the run end type was not given and cannot be inferred.
        """
        nodes = [node_def.node for node_def in self.node_defs.values()]
        state_t = self._get_state_type()
        end_t = self._get_run_end_type()
        # `nodes_schema_var` is set only while the adapter is built — presumably read during `HistoryStep`
        # schema generation to learn this graph's node classes (TODO confirm against `state.py`).
        token = nodes_schema_var.set(nodes)
        try:
            ta = pydantic.TypeAdapter(list[Annotated[HistoryStep[state_t, end_t], pydantic.Discriminator('kind')]])
        finally:
            nodes_schema_var.reset(token)
        return ta

    def mermaid_code(
        self,
        *,
        start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        title: str | None | typing_extensions.Literal[False] = None,
        edge_labels: bool = True,
        notes: bool = True,
        highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS,
        infer_name: bool = True,
        direction: mermaid.StateDiagramDirection | None = None,
    ) -> str:
        """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram.

        This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code].

        Args:
            start_node: The node or nodes which can start the graph.
            title: The title of the diagram, use `False` to not include a title.
            edge_labels: Whether to include edge labels.
            notes: Whether to include notes on each node.
            highlighted_nodes: Optional node or nodes to highlight.
            highlight_css: The CSS to use for highlighting nodes.
            infer_name: Whether to infer the graph name from the calling frame.
            direction: The direction of flow.

        Returns:
            The mermaid code for the graph, which can then be rendered as a diagram.

        Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]:

        ```py {title="mermaid_never_42.py" py="3.10"}
        from never_42 import Increment, never_42_graph

        print(never_42_graph.mermaid_code(start_node=Increment))
        '''
        ---
        title: never_42_graph
        ---
        stateDiagram-v2
          [*] --> Increment
          Increment --> Check42
          Check42 --> Increment
          Check42 --> [*]
        '''
        ```

        The rendered diagram will look like this:

        ```mermaid
        ---
        title: never_42_graph
        ---
        stateDiagram-v2
          [*] --> Increment
          Increment --> Check42
          Check42 --> Increment
          Check42 --> [*]
        ```
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        if title is None and self.name:
            title = self.name
        return mermaid.generate_code(
            self,
            start_node=start_node,
            highlighted_nodes=highlighted_nodes,
            highlight_css=highlight_css,
            # `False` (explicitly no title) and an empty string both become `None` here.
            title=title or None,
            edge_labels=edge_labels,
            notes=notes,
            direction=direction,
        )

    def mermaid_image(
        self, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
    ) -> bytes:
        """Generate a diagram representing the graph as an image.

        The format and diagram can be customized using `kwargs`,
        see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

        !!! note "Uses external service"
            This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
            is a free service not affiliated with Pydantic.

        Args:
            infer_name: Whether to infer the graph name from the calling frame.
            **kwargs: Additional arguments to pass to `mermaid.request_image`.

        Returns:
            The image bytes.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        if 'title' not in kwargs and self.name:
            kwargs['title'] = self.name
        return mermaid.request_image(self, **kwargs)

    def mermaid_save(
        self, path: Path | str, /, *, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
    ) -> None:
        """Generate a diagram representing the graph and save it as an image.

        The format and diagram can be customized using `kwargs`,
        see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

        !!! note "Uses external service"
            This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
            is a free service not affiliated with Pydantic.

        Args:
            path: The path to save the image to.
            infer_name: Whether to infer the graph name from the calling frame.
            **kwargs: Additional arguments to pass to `mermaid.save_image`.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        if 'title' not in kwargs and self.name:
            kwargs['title'] = self.name
        mermaid.save_image(path, self, **kwargs)

    def _get_state_type(self) -> type[StateT]:
        """Return the explicit state type, or infer it from the first `BaseNode[...]` base of the nodes.

        Falls back to `NoneType` when no node declares a state type argument.
        """
        if _utils.is_set(self._state_type):
            return self._state_type

        for node_def in self.node_defs.values():
            for base in typing_extensions.get_original_bases(node_def.node):
                if typing_extensions.get_origin(base) is BaseNode:
                    args = typing_extensions.get_args(base)
                    if args:
                        return args[0]
                    # break the inner (bases) loop: only the first `BaseNode` base of each node is inspected
                    break
        # state defaults to None, so use that if we can't infer it
        return type(None)  # pyright: ignore[reportReturnType]

    def _get_run_end_type(self) -> type[RunEndT]:
        """Return the explicit run end type, or infer it from a node's `BaseNode[State, Deps, RunEnd]` base.

        Raises:
            GraphSetupError: If no node declares a (non-`Never`) run end type and none was provided.
        """
        if _utils.is_set(self._run_end_type):
            return self._run_end_type

        for node_def in self.node_defs.values():
            for base in typing_extensions.get_original_bases(node_def.node):
                if typing_extensions.get_origin(base) is BaseNode:
                    args = typing_extensions.get_args(base)
                    if len(args) == 3:
                        t = args[2]
                        if not _utils.is_never(t):
                            return t
                    # break the inner (bases) loop: only the first `BaseNode` base of each node is inspected
                    break
        raise exceptions.GraphSetupError('Could not infer run end type from nodes, please set `run_end_type`.')

    def _register_node(
        self: Graph[StateT, DepsT, T],
        node: type[BaseNode[StateT, DepsT, T]],
        parent_namespace: dict[str, Any] | None,
    ) -> None:
        """Add `node`'s definition to `node_defs`, raising `GraphSetupError` if its ID is already registered."""
        node_id = node.get_id()
        if existing_node := self.node_defs.get(node_id):
            raise exceptions.GraphSetupError(
                f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}'
            )
        else:
            self.node_defs[node_id] = node.get_node_def(parent_namespace)

    def _validate_edges(self):
        """Raise `GraphSetupError` if any node has an outgoing edge to a node ID not registered in this graph."""
        known_node_ids = self.node_defs.keys()
        # missing node ID -> formatted IDs of the nodes that reference it
        bad_edges: dict[str, list[str]] = {}

        for node_id, node_def in self.node_defs.items():
            for edge in node_def.next_node_edges.keys():
                if edge not in known_node_ids:
                    bad_edges.setdefault(edge, []).append(f'`{node_id}`')

        if bad_edges:
            bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()]
            if len(bad_edges_list) == 1:
                raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.')
            else:
                b = '\n'.join(f' {be}' for be in bad_edges_list)
                raise exceptions.GraphSetupError(
                    f'Nodes are referenced in the graph but not included in the graph:\n{b}'
                )

    def _infer_name(self, function_frame: types.FrameType | None) -> None:
        """Infer the graph name from the call frame.

        Usage should be `self._infer_name(inspect.currentframe())`: `function_frame` is the frame of the graph
        method that was called, and the name is looked up in its caller's frame (`f_back`), first among the
        locals and then among the globals. If the graph isn't bound to a name there, `self.name` stays `None`.

        Copied from `Agent`.
        """
        assert self.name is None, 'Name already set'
        if function_frame is not None and (parent_frame := function_frame.f_back):  # pragma: no branch
            for name, item in parent_frame.f_locals.items():
                if item is self:
                    self.name = name
                    return
            if parent_frame.f_locals != parent_frame.f_globals:
                # if we couldn't find the graph in locals and globals are a different dict, try globals
                for name, item in parent_frame.f_globals.items():
                    if item is self:
                        self.name = name
                        return