Coverage for pydantic_graph/pydantic_graph/graph.py: 97.97%

170 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-30 19:21 +0000

1from __future__ import annotations as _annotations 

2 

3import asyncio 

4import inspect 

5import types 

6from collections.abc import Sequence 

7from contextlib import ExitStack 

8from dataclasses import dataclass, field 

9from functools import cached_property 

10from pathlib import Path 

11from time import perf_counter 

12from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, TypeVar 

13 

14import logfire_api 

15import pydantic 

16import typing_extensions 

17 

18from . import _utils, exceptions, mermaid 

19from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT 

20from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var 

21 

# While waiting for https://github.com/pydantic/logfire/issues/745: if the full `logfire` package is
# importable, append this package's directory to its `NON_USER_CODE_PREFIXES`.
try:
    import logfire._internal.stack_info
except ImportError:
    # `logfire` is optional; without it there is nothing to patch.
    pass
else:
    # `Path` is already imported at module level, so no local re-import is needed here.
    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)


__all__ = ('Graph',)

# Logfire handle used for the spans created by `Graph.run` and `Graph.next`.
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')

T = TypeVar('T')
"""An invariant typevar."""

39 

40 

@dataclass(init=False)
class Graph(Generic[StateT, DepsT, RunEndT]):
    """Definition of a graph.

    In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define
    their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph.

    Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never
    42 at the end.

    ```py {title="never_42.py" noqa="I001" py="3.10"}
    from __future__ import annotations

    from dataclasses import dataclass

    from pydantic_graph import BaseNode, End, Graph, GraphRunContext

    @dataclass
    class MyState:
        number: int

    @dataclass
    class Increment(BaseNode[MyState]):
        async def run(self, ctx: GraphRunContext) -> Check42:
            ctx.state.number += 1
            return Check42()

    @dataclass
    class Check42(BaseNode[MyState, None, int]):
        async def run(self, ctx: GraphRunContext) -> Increment | End[int]:
            if ctx.state.number == 42:
                return Increment()
            else:
                return End(ctx.state.number)

    never_42_graph = Graph(nodes=(Increment, Check42))
    ```
    _(This example is complete, it can be run "as is")_

    See [`run`][pydantic_graph.graph.Graph.run] for an example of running a graph, and
    [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] for an example of generating a mermaid diagram
    from the graph.
    """

    # Optional graph name; when `None` it is inferred from the calling frame on first use.
    name: str | None
    # Node definitions keyed by node ID, filled in by `__init__` via `_register_node`.
    node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]]
    # Callable passed to each `NodeStep` so it can snapshot the state.
    snapshot_state: Callable[[StateT], StateT]
    # Explicit state / run-end types; when unset they are inferred from the node classes.
    _state_type: type[StateT] | _utils.Unset = field(repr=False)
    _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False)
    # Whether `run` and `next` wrap their work in logfire spans.
    _auto_instrument: bool = field(repr=False)

    def __init__(
        self,
        *,
        nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]],
        name: str | None = None,
        state_type: type[StateT] | _utils.Unset = _utils.UNSET,
        run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET,
        snapshot_state: Callable[[StateT], StateT] = deep_copy_state,
        auto_instrument: bool = True,
    ):
        """Create a graph from a sequence of nodes.

        Args:
            nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same
                state type.
            name: Optional name for the graph, if not provided the name will be inferred from the calling frame
                on the first call to a graph method.
            state_type: The type of the state for the graph, this can generally be inferred from `nodes`.
            run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`.
            snapshot_state: A function to snapshot the state of the graph, this is used in
                [`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record
                the state before each step.
            auto_instrument: Whether to create a span for the graph run and the execution of each node's run method.

        Raises:
            GraphSetupError: If two nodes share the same ID, or a node references a node that is not in `nodes`.
        """
        self.name = name
        self._state_type = state_type
        self._run_end_type = run_end_type
        self._auto_instrument = auto_instrument
        self.snapshot_state = snapshot_state

        # The caller's namespace is handed to each node's `get_node_def`.
        parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
        self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {}
        for node in nodes:
            self._register_node(node, parent_namespace)

        self._validate_edges()

    async def run(
        self: Graph[StateT, DepsT, T],
        start_node: BaseNode[StateT, DepsT, T],
        *,
        state: StateT = None,
        deps: DepsT = None,
        infer_name: bool = True,
    ) -> tuple[T, list[HistoryStep[StateT, T]]]:
        """Run the graph from a starting node until it ends.

        Args:
            start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The result type from ending the run and the history of the run.

        Raises:
            GraphRuntimeError: If a node is not part of the graph, or a node returns something that is neither a
                `BaseNode` nor an `End`.

        Here's an example of running the graph from [above][pydantic_graph.graph.Graph]:

        ```py {title="run_never_42.py" noqa="I001" py="3.10"}
        from never_42 import Increment, MyState, never_42_graph

        async def main():
            state = MyState(1)
            _, history = await never_42_graph.run(Increment(), state=state)
            print(state)
            #> MyState(number=2)
            print(len(history))
            #> 3

            state = MyState(41)
            _, history = await never_42_graph.run(Increment(), state=state)
            print(state)
            #> MyState(number=43)
            print(len(history))
            #> 5
        ```
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        history: list[HistoryStep[StateT, T]] = []
        with ExitStack() as stack:
            run_span: logfire_api.LogfireSpan | None = None
            if self._auto_instrument:
                run_span = stack.enter_context(
                    _logfire.span(
                        '{graph_name} run {start=}',
                        graph_name=self.name or 'graph',
                        start=start_node,
                    )
                )
            while True:
                # The name was already handled above, so `next` must not try to infer it again.
                next_node = await self.next(start_node, history, state=state, deps=deps, infer_name=False)
                if isinstance(next_node, End):
                    history.append(EndStep(result=next_node))
                    if run_span is not None:
                        run_span.set_attribute('history', history)
                    return next_node.data, history
                elif isinstance(next_node, BaseNode):
                    start_node = next_node
                else:
                    if TYPE_CHECKING:
                        typing_extensions.assert_never(next_node)
                    else:
                        raise exceptions.GraphRuntimeError(
                            f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
                        )

    def run_sync(
        self: Graph[StateT, DepsT, T],
        start_node: BaseNode[StateT, DepsT, T],
        *,
        state: StateT = None,
        deps: DepsT = None,
        infer_name: bool = True,
    ) -> tuple[T, list[HistoryStep[StateT, T]]]:
        """Run the graph synchronously.

        This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
        You therefore can't use this method inside async code or if there's an active event loop.

        If `asyncio.get_event_loop()` raises `RuntimeError` because the current thread has no event loop, a new
        loop is created and installed as the thread's current loop before running.

        Args:
            start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The result type from ending the run and the history of the run.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current event loop in this thread: create one rather than failing.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(self.run(start_node, state=state, deps=deps, infer_name=False))

    async def next(
        self: Graph[StateT, DepsT, T],
        node: BaseNode[StateT, DepsT, T],
        history: list[HistoryStep[StateT, T]],
        *,
        state: StateT = None,
        deps: DepsT = None,
        infer_name: bool = True,
    ) -> BaseNode[StateT, DepsT, Any] | End[T]:
        """Run a node in the graph and return the next node to run.

        Args:
            node: The node to run.
            history: The history of the graph run so far. NOTE: this will be mutated to add the new step.
            state: The current state of the graph.
            deps: The dependencies of the graph.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished.

        Raises:
            GraphRuntimeError: If `node` is not one of the nodes registered on this graph.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        node_id = node.get_id()
        if node_id not in self.node_defs:
            raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.')

        with ExitStack() as stack:
            if self._auto_instrument:
                stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node))
            ctx = GraphRunContext(state, deps)
            # Wall-clock start for the history record; `perf_counter` for the duration.
            start_ts = _utils.now_utc()
            start = perf_counter()
            next_node = await node.run(ctx)
            duration = perf_counter() - start

        history.append(
            NodeStep(state=state, node=node, start_ts=start_ts, duration=duration, snapshot_state=self.snapshot_state)
        )
        return next_node

    def dump_history(
        self: Graph[StateT, DepsT, T], history: list[HistoryStep[StateT, T]], *, indent: int | None = None
    ) -> bytes:
        """Dump the history of a graph run as JSON.

        Args:
            history: The history of the graph run.
            indent: The number of spaces to indent the JSON.

        Returns:
            The JSON representation of the history.
        """
        return self.history_type_adapter.dump_json(history, indent=indent)

    def load_history(self, json_bytes: str | bytes | bytearray) -> list[HistoryStep[StateT, RunEndT]]:
        """Load the history of a graph run from JSON.

        Args:
            json_bytes: The JSON representation of the history.

        Returns:
            The history of the graph run.
        """
        return self.history_type_adapter.validate_json(json_bytes)

    @cached_property
    def history_type_adapter(self) -> pydantic.TypeAdapter[list[HistoryStep[StateT, RunEndT]]]:
        """Type adapter used by `dump_history` and `load_history`, built once per graph.

        While the adapter is constructed, `nodes_schema_var` is set to this graph's node classes; it is reset
        afterwards even if construction fails.
        """
        nodes = [node_def.node for node_def in self.node_defs.values()]
        state_t = self._get_state_type()
        end_t = self._get_run_end_type()
        token = nodes_schema_var.set(nodes)
        try:
            ta = pydantic.TypeAdapter(list[Annotated[HistoryStep[state_t, end_t], pydantic.Discriminator('kind')]])
        finally:
            nodes_schema_var.reset(token)
        return ta

    def mermaid_code(
        self,
        *,
        start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        title: str | None | typing_extensions.Literal[False] = None,
        edge_labels: bool = True,
        notes: bool = True,
        highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS,
        infer_name: bool = True,
        direction: mermaid.StateDiagramDirection | None = None,
    ) -> str:
        """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram.

        This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code].

        Args:
            start_node: The node or nodes which can start the graph.
            title: The title of the diagram, use `False` to not include a title.
            edge_labels: Whether to include edge labels.
            notes: Whether to include notes on each node.
            highlighted_nodes: Optional node or nodes to highlight.
            highlight_css: The CSS to use for highlighting nodes.
            infer_name: Whether to infer the graph name from the calling frame.
            direction: The direction of flow.

        Returns:
            The mermaid code for the graph, which can then be rendered as a diagram.

        Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]:

        ```py {title="never_42.py" py="3.10"}
        from never_42 import Increment, never_42_graph

        print(never_42_graph.mermaid_code(start_node=Increment))
        '''
        ---
        title: never_42_graph
        ---
        stateDiagram-v2
          [*] --> Increment
          Increment --> Check42
          Check42 --> Increment
          Check42 --> [*]
        '''
        ```

        The rendered diagram will look like this:

        ```mermaid
        ---
        title: never_42_graph
        ---
        stateDiagram-v2
          [*] --> Increment
          Increment --> Check42
          Check42 --> Increment
          Check42 --> [*]
        ```
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        # Only fall back to the graph name when no title was given; `False` means "no title".
        if title is None and self.name:
            title = self.name
        return mermaid.generate_code(
            self,
            start_node=start_node,
            highlighted_nodes=highlighted_nodes,
            highlight_css=highlight_css,
            title=title or None,
            edge_labels=edge_labels,
            notes=notes,
            direction=direction,
        )

    def mermaid_image(
        self, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
    ) -> bytes:
        """Generate a diagram representing the graph as an image.

        The format and diagram can be customized using `kwargs`,
        see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

        !!! note "Uses external service"
            This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
            is a free service not affiliated with Pydantic.

        Args:
            infer_name: Whether to infer the graph name from the calling frame.
            **kwargs: Additional arguments to pass to `mermaid.request_image`.

        Returns:
            The image bytes.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        if 'title' not in kwargs and self.name:
            kwargs['title'] = self.name
        return mermaid.request_image(self, **kwargs)

    def mermaid_save(
        self, path: Path | str, /, *, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
    ) -> None:
        """Generate a diagram representing the graph and save it as an image.

        The format and diagram can be customized using `kwargs`,
        see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

        !!! note "Uses external service"
            This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
            is a free service not affiliated with Pydantic.

        Args:
            path: The path to save the image to.
            infer_name: Whether to infer the graph name from the calling frame.
            **kwargs: Additional arguments to pass to `mermaid.save_image`.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        if 'title' not in kwargs and self.name:
            kwargs['title'] = self.name
        mermaid.save_image(path, self, **kwargs)

    def _get_state_type(self) -> type[StateT]:
        """Return the explicit state type, or the first generic argument found on a node's `BaseNode` base.

        Falls back to `type(None)` when no node parameterizes `BaseNode`.
        """
        if _utils.is_set(self._state_type):
            return self._state_type

        for node_def in self.node_defs.values():
            for base in typing_extensions.get_original_bases(node_def.node):
                if typing_extensions.get_origin(base) is BaseNode:
                    args = typing_extensions.get_args(base)
                    if args:
                        return args[0]
                    # break the inner (bases) loop
                    break
        # state defaults to None, so use that if we can't infer it
        return type(None)  # pyright: ignore[reportReturnType]

    def _get_run_end_type(self) -> type[RunEndT]:
        """Return the explicit run-end type, or the third generic argument found on a node's `BaseNode` base.

        Raises:
            GraphSetupError: If no node provides a usable third argument and `run_end_type` was not set.
        """
        if _utils.is_set(self._run_end_type):
            return self._run_end_type

        for node_def in self.node_defs.values():
            for base in typing_extensions.get_original_bases(node_def.node):
                if typing_extensions.get_origin(base) is BaseNode:
                    args = typing_extensions.get_args(base)
                    if len(args) == 3:
                        t = args[2]
                        if not _utils.is_never(t):
                            return t
                    # break the inner (bases) loop
                    break
        raise exceptions.GraphSetupError('Could not infer run end type from nodes, please set `run_end_type`.')

    def _register_node(
        self: Graph[StateT, DepsT, T],
        node: type[BaseNode[StateT, DepsT, T]],
        parent_namespace: dict[str, Any] | None,
    ) -> None:
        """Add `node`'s definition to `node_defs` under its ID.

        Raises:
            GraphSetupError: If another node with the same ID is already registered.
        """
        node_id = node.get_id()
        if existing_node := self.node_defs.get(node_id):
            raise exceptions.GraphSetupError(
                f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}'
            )
        else:
            self.node_defs[node_id] = node.get_node_def(parent_namespace)

    def _validate_edges(self):
        """Check that every edge declared by a registered node points at a registered node.

        Raises:
            GraphSetupError: Listing each unknown node ID together with the nodes that reference it.
        """
        # Unknown target node ID -> back-quoted IDs of the nodes that reference it.
        bad_edges: dict[str, list[str]] = {}

        for node_id, node_def in self.node_defs.items():
            for edge in node_def.next_node_edges.keys():
                if edge not in self.node_defs:
                    bad_edges.setdefault(edge, []).append(f'`{node_id}`')

        if bad_edges:
            bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()]
            if len(bad_edges_list) == 1:
                raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.')
            else:
                b = '\n'.join(f' {be}' for be in bad_edges_list)
                raise exceptions.GraphSetupError(
                    f'Nodes are referenced in the graph but not included in the graph:\n{b}'
                )

    def _infer_name(self, function_frame: types.FrameType | None) -> None:
        """Infer the graph name from the call frame.

        Looks for a variable bound to this graph in the caller's locals, then in its globals, and uses that
        variable's name. If none is found, `name` stays `None`.

        Usage should be `self._infer_name(inspect.currentframe())`.

        Copied from `Agent`.
        """
        assert self.name is None, 'Name already set'
        if function_frame is not None and (parent_frame := function_frame.f_back):  # pragma: no branch
            for name, item in parent_frame.f_locals.items():
                if item is self:
                    self.name = name
                    return
            if parent_frame.f_locals != parent_frame.f_globals:
                # if we couldn't find the graph in locals and globals are a different dict, try globals
                for name, item in parent_frame.f_globals.items():
                    if item is self:
                        self.name = name
                        return