Coverage for pydantic_graph/pydantic_graph/graph.py: 97.55%

223 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3import inspect 

4import types 

5from collections.abc import AsyncIterator, Sequence 

6from contextlib import AbstractContextManager, ExitStack, asynccontextmanager 

7from dataclasses import dataclass, field 

8from functools import cached_property 

9from typing import Any, Generic, cast 

10 

11import logfire_api 

12import typing_extensions 

13from logfire_api import LogfireSpan 

14from typing_extensions import deprecated 

15from typing_inspection import typing_objects 

16 

17from . import _utils, exceptions, mermaid 

18from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT 

19from .persistence import BaseStatePersistence 

20from .persistence.in_mem import SimpleStatePersistence 

21 

# `Path` is imported unconditionally: `Graph.mermaid_save` names it in its signature, so the name has to
# exist at module level even when the optional `logfire` package (and with it the `else` branch below)
# is not available.
from pathlib import Path

# while waiting for https://github.com/pydantic/logfire/issues/745
try:
    import logfire._internal.stack_info
except ImportError:
    # logfire is optional; without it there is nothing to configure
    pass
else:
    # add this package's directory to logfire's list of non-user code prefixes
    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

31 

32 

# Names exported by `from pydantic_graph.graph import *`.
__all__ = ('Graph', 'GraphRun', 'GraphRunResult')

# Logfire instance used for the per-node spans, created with the `pydantic-graph` otel scope.
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')

36 

37 

@dataclass(init=False)
class Graph(Generic[StateT, DepsT, RunEndT]):
    """Definition of a graph.

    In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define
    their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph.

    Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never
    42 at the end.

    ```py {title="never_42.py" noqa="I001" py="3.10"}
    from __future__ import annotations

    from dataclasses import dataclass

    from pydantic_graph import BaseNode, End, Graph, GraphRunContext

    @dataclass
    class MyState:
        number: int

    @dataclass
    class Increment(BaseNode[MyState]):
        async def run(self, ctx: GraphRunContext) -> Check42:
            ctx.state.number += 1
            return Check42()

    @dataclass
    class Check42(BaseNode[MyState, None, int]):
        async def run(self, ctx: GraphRunContext) -> Increment | End[int]:
            if ctx.state.number == 42:
                return Increment()
            else:
                return End(ctx.state.number)

    never_42_graph = Graph(nodes=(Increment, Check42))
    ```
    _(This example is complete, it can be run "as is")_

    See [`run`][pydantic_graph.graph.Graph.run] for an example of running the graph, and
    [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] for an example of generating a mermaid diagram
    from the graph.
    """

    name: str | None
    node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]]
    _state_type: type[StateT] | _utils.Unset = field(repr=False)
    _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False)
    auto_instrument: bool = field(repr=False)

    def __init__(
        self,
        *,
        nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]],
        name: str | None = None,
        state_type: type[StateT] | _utils.Unset = _utils.UNSET,
        run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET,
        auto_instrument: bool = True,
    ):
        """Create a graph from a sequence of nodes.

        Args:
            nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same
                state type.
            name: Optional name for the graph, if not provided the name will be inferred from the calling frame
                on the first call to a graph method.
            state_type: The type of the state for the graph, this can generally be inferred from `nodes`.
            run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`.
            auto_instrument: Whether to create a span for the graph run and the execution of each node's run method.

        Raises:
            GraphSetupError: If two nodes share a node ID, or a node refers to a next node that is not in `nodes`.
        """
        self.name = name
        self._state_type = state_type
        self._run_end_type = run_end_type
        self.auto_instrument = auto_instrument

        # must be called directly in `__init__`: the helper is handed this method's own frame
        parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
        self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {}
        for node in nodes:
            self._register_node(node, parent_namespace)

        self._validate_edges()

    async def run(
        self,
        start_node: BaseNode[StateT, DepsT, RunEndT],
        *,
        state: StateT = None,
        deps: DepsT = None,
        persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
        infer_name: bool = True,
        span: LogfireSpan | None = None,
    ) -> GraphRunResult[StateT, RunEndT]:
        """Run the graph from a starting node until it ends.

        Args:
            start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            persistence: State persistence interface, defaults to
                [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`.
            infer_name: Whether to infer the graph name from the calling frame.
            span: The span to use for the graph run. If not provided, a span will be created depending on the value of
                the `auto_instrument` field.

        Returns:
            A `GraphRunResult` containing information about the run, including its final result.

        Here's an example of running the graph from [above][pydantic_graph.graph.Graph]:

        ```py {title="run_never_42.py" noqa="I001" py="3.10"}
        from never_42 import Increment, MyState, never_42_graph

        async def main():
            state = MyState(1)
            await never_42_graph.run(Increment(), state=state)
            print(state)
            #> MyState(number=2)

            state = MyState(41)
            await never_42_graph.run(Increment(), state=state)
            print(state)
            #> MyState(number=43)
        ```
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        # the name (if any) was inferred above from the right frame, so `iter` must not try again
        async with self.iter(
            start_node, state=state, deps=deps, persistence=persistence, span=span, infer_name=False
        ) as graph_run:
            # drive the run to completion; the intermediate nodes are not needed here
            async for _node in graph_run:
                pass

        final_result = graph_run.result
        assert final_result is not None, 'GraphRun should have a final result'
        return final_result

    def run_sync(
        self,
        start_node: BaseNode[StateT, DepsT, RunEndT],
        *,
        state: StateT = None,
        deps: DepsT = None,
        persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
        infer_name: bool = True,
    ) -> GraphRunResult[StateT, RunEndT]:
        """Synchronously run the graph.

        This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
        You therefore can't use this method inside async code or if there's an active event loop.

        Args:
            start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            persistence: State persistence interface, defaults to
                [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The result type from ending the run and the history of the run.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        return _utils.get_event_loop().run_until_complete(
            self.run(start_node, state=state, deps=deps, persistence=persistence, infer_name=False)
        )

    @asynccontextmanager
    async def iter(
        self,
        start_node: BaseNode[StateT, DepsT, RunEndT],
        *,
        state: StateT = None,
        deps: DepsT = None,
        persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
        span: AbstractContextManager[Any] | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
        """A contextmanager which can be used to iterate over the graph's nodes as they are executed.

        This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as
        they are executed. This is the API to use if you want to record or interact with the nodes as the graph
        execution unfolds.

        The `GraphRun` can also be used to manually drive the graph execution by calling
        [`GraphRun.next`][pydantic_graph.graph.GraphRun.next].

        The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once
        it has completed.

        For more details, see the API documentation of [`GraphRun`][pydantic_graph.graph.GraphRun].

        Args:
            start_node: the first node to run. Since the graph definition doesn't define the entry point in the graph,
                you need to provide the starting node.
            state: The initial state of the graph.
            deps: The dependencies of the graph.
            persistence: State persistence interface, defaults to
                [`SimpleStatePersistence`][pydantic_graph.SimpleStatePersistence] if `None`.
            span: The span to use for the graph run. If not provided, a new span will be created when
                `auto_instrument` is enabled.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns: A GraphRun that can be async iterated over to drive the graph to completion.
        """
        if infer_name and self.name is None:
            # f_back because `asynccontextmanager` adds one frame
            if frame := inspect.currentframe():  # pragma: no branch
                self._infer_name(frame.f_back)

        if persistence is None:
            persistence = SimpleStatePersistence()
        persistence.set_graph_types(self)

        if self.auto_instrument and span is None:
            span = logfire_api.span('run graph {graph.name}', graph=self)

        # `ExitStack` lets the span be entered only when there is one
        with ExitStack() as stack:
            if span is not None:
                stack.enter_context(span)
            yield GraphRun[StateT, DepsT, RunEndT](
                graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps
            )

    @asynccontextmanager
    async def iter_from_persistence(
        self,
        persistence: BaseStatePersistence[StateT, RunEndT],
        *,
        deps: DepsT = None,
        span: AbstractContextManager[Any] | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
        """A contextmanager to iterate over the graph's nodes as they are executed, created from a persistence object.

        This method has similar functionality to [`iter`][pydantic_graph.graph.Graph.iter],
        but instead of passing the node to run, it will restore the node and state from state persistence.

        Args:
            persistence: The state persistence interface to use.
            deps: The dependencies of the graph.
            span: The span to use for the graph run. If not provided, a new span will be created when
                `auto_instrument` is enabled.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns: A GraphRun that can be async iterated over to drive the graph to completion.

        Raises:
            GraphRuntimeError: If `persistence.load_next()` returns no snapshot to resume from.
        """
        if infer_name and self.name is None:
            # f_back because `asynccontextmanager` adds one frame
            if frame := inspect.currentframe():  # pragma: no branch
                self._infer_name(frame.f_back)

        persistence.set_graph_types(self)

        snapshot = await persistence.load_next()
        if snapshot is None:
            raise exceptions.GraphRuntimeError('Unable to restore snapshot from state persistence.')

        # tag the restored node with the ID of the snapshot it was loaded from
        snapshot.node.set_snapshot_id(snapshot.id)

        if self.auto_instrument and span is None:
            span = logfire_api.span('run graph {graph.name}', graph=self)

        # `ExitStack` lets the span be entered only when there is one
        with ExitStack() as stack:
            if span is not None:
                stack.enter_context(span)
            yield GraphRun[StateT, DepsT, RunEndT](
                graph=self,
                start_node=snapshot.node,
                persistence=persistence,
                state=snapshot.state,
                deps=deps,
                snapshot_id=snapshot.id,
            )

    async def initialize(
        self,
        node: BaseNode[StateT, DepsT, RunEndT],
        persistence: BaseStatePersistence[StateT, RunEndT],
        *,
        state: StateT = None,
        infer_name: bool = True,
    ) -> None:
        """Initialize a new graph run in persistence without running it.

        This is useful if you want to set up a graph run to be run later, e.g. via
        [`iter_from_persistence`][pydantic_graph.graph.Graph.iter_from_persistence].

        Args:
            node: The node to run first.
            persistence: State persistence interface.
            state: The start state of the graph.
            infer_name: Whether to infer the graph name from the calling frame.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        persistence.set_graph_types(self)
        await persistence.snapshot_node(state, node)

    @deprecated('`next` is deprecated, use `async with graph.iter(...) as run:  run.next()` instead')
    async def next(
        self,
        node: BaseNode[StateT, DepsT, RunEndT],
        persistence: BaseStatePersistence[StateT, RunEndT],
        *,
        state: StateT = None,
        deps: DepsT = None,
        infer_name: bool = True,
    ) -> BaseNode[StateT, DepsT, Any] | End[RunEndT]:
        """Run a node in the graph and return the next node to run.

        Args:
            node: The node to run.
            persistence: State persistence interface; required, there is no default for this method.
            state: The current state of the graph.
            deps: The dependencies of the graph.
            infer_name: Whether to infer the graph name from the calling frame.

        Returns:
            The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())

        persistence.set_graph_types(self)
        run = GraphRun[StateT, DepsT, RunEndT](
            graph=self,
            start_node=node,
            persistence=persistence,
            state=state,
            deps=deps,
        )
        return await run.next(node)

    def mermaid_code(
        self,
        *,
        start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        title: str | None | typing_extensions.Literal[False] = None,
        edge_labels: bool = True,
        notes: bool = True,
        highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None,
        highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS,
        infer_name: bool = True,
        direction: mermaid.StateDiagramDirection | None = None,
    ) -> str:
        """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram.

        This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code].

        Args:
            start_node: The node or nodes which can start the graph.
            title: The title of the diagram, use `False` to not include a title. If `None`, the graph name
                is used when there is one.
            edge_labels: Whether to include edge labels.
            notes: Whether to include notes on each node.
            highlighted_nodes: Optional node or nodes to highlight.
            highlight_css: The CSS to use for highlighting nodes.
            infer_name: Whether to infer the graph name from the calling frame.
            direction: The direction of flow.

        Returns:
            The mermaid code for the graph, which can then be rendered as a diagram.

        Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]:

        ```py {title="mermaid_never_42.py" py="3.10"}
        from never_42 import Increment, never_42_graph

        print(never_42_graph.mermaid_code(start_node=Increment))
        '''
        ---
        title: never_42_graph
        ---
        stateDiagram-v2
          [*] --> Increment
          Increment --> Check42
          Check42 --> Increment
          Check42 --> [*]
        '''
        ```

        The rendered diagram will look like this:

        ```mermaid
        ---
        title: never_42_graph
        ---
        stateDiagram-v2
          [*] --> Increment
          Increment --> Check42
          Check42 --> Increment
          Check42 --> [*]
        ```
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        if title is None and self.name:
            title = self.name
        return mermaid.generate_code(
            self,
            start_node=start_node,
            highlighted_nodes=highlighted_nodes,
            highlight_css=highlight_css,
            # `False` (explicitly no title) and an empty string are both passed on as `None`
            title=title or None,
            edge_labels=edge_labels,
            notes=notes,
            direction=direction,
        )

    def mermaid_image(
        self, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
    ) -> bytes:
        """Generate a diagram representing the graph as an image.

        The format and diagram can be customized using `kwargs`,
        see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

        !!! note "Uses external service"
            This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
            is a free service not affiliated with Pydantic.

        Args:
            infer_name: Whether to infer the graph name from the calling frame.
            **kwargs: Additional arguments to pass to `mermaid.request_image`.

        Returns:
            The image bytes.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        # default the diagram title to the graph name unless the caller chose one
        if 'title' not in kwargs and self.name:
            kwargs['title'] = self.name
        return mermaid.request_image(self, **kwargs)

    def mermaid_save(
        self, path: Path | str, /, *, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig]
    ) -> None:
        """Generate a diagram representing the graph and save it as an image.

        The format and diagram can be customized using `kwargs`,
        see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig].

        !!! note "Uses external service"
            This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink`
            is a free service not affiliated with Pydantic.

        Args:
            path: The path to save the image to.
            infer_name: Whether to infer the graph name from the calling frame.
            **kwargs: Additional arguments to pass to `mermaid.save_image`.
        """
        if infer_name and self.name is None:
            self._infer_name(inspect.currentframe())
        # default the diagram title to the graph name unless the caller chose one
        if 'title' not in kwargs and self.name:
            kwargs['title'] = self.name
        mermaid.save_image(path, self, **kwargs)

    def get_nodes(self) -> Sequence[type[BaseNode[StateT, DepsT, RunEndT]]]:
        """Get the nodes in the graph."""
        return [node_def.node for node_def in self.node_defs.values()]

    @cached_property
    def inferred_types(self) -> tuple[type[StateT], type[RunEndT]]:
        """The state type and run end type of the graph, as a `(state_type, run_end_type)` tuple.

        Types passed explicitly to the constructor take precedence. A missing type is read from the
        `BaseNode[...]` base of the registered nodes: the first generic argument is taken as the state type,
        and the third as the run end type unless it is `Never`. Whatever still can't be determined is
        reported as `None`.
        """
        if _utils.is_set(self._state_type) and _utils.is_set(self._run_end_type):
            return self._state_type, self._run_end_type

        state_type = self._state_type
        run_end_type = self._run_end_type

        for node_def in self.node_defs.values():
            for base in typing_extensions.get_original_bases(node_def.node):
                if typing_extensions.get_origin(base) is BaseNode:
                    args = typing_extensions.get_args(base)
                    if not _utils.is_set(state_type) and args:
                        state_type = args[0]

                    if not _utils.is_set(run_end_type) and len(args) == 3:
                        t = args[2]
                        if not typing_objects.is_never(t):
                            run_end_type = t
                    if _utils.is_set(state_type) and _utils.is_set(run_end_type):
                        return state_type, run_end_type  # pyright: ignore[reportReturnType]
                    # break the inner (bases) loop
                    break

        if not _utils.is_set(state_type):
            # state defaults to None, so use that if we can't infer it
            state_type = None
        if not _utils.is_set(run_end_type):
            # this happens if a graph has no return nodes, use None so any downstream errors are clear
            run_end_type = None
        return state_type, run_end_type  # pyright: ignore[reportReturnType]

    def _register_node(
        self,
        node: type[BaseNode[StateT, DepsT, RunEndT]],
        parent_namespace: dict[str, Any] | None,
    ) -> None:
        """Add the definition of `node` to `self.node_defs`, keyed by its node ID.

        Args:
            node: The node class to register.
            parent_namespace: Namespace handed to `node.get_node_def(...)`.

        Raises:
            GraphSetupError: If another node with the same ID is already registered.
        """
        node_id = node.get_node_id()
        if existing_node := self.node_defs.get(node_id):
            raise exceptions.GraphSetupError(
                f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}'
            )
        self.node_defs[node_id] = node.get_node_def(parent_namespace)

    def _validate_edges(self):
        """Check that every next-node edge refers to a node registered in this graph.

        Raises:
            GraphSetupError: If one or more edges refer to unknown node IDs; the message lists each unknown
                ID together with the nodes that reference it.
        """
        # unknown node ID -> the (back-quoted) IDs of the nodes which refer to it
        bad_edges: dict[str, list[str]] = {}

        for node_id, node_def in self.node_defs.items():
            for edge in node_def.next_node_edges:
                if edge not in self.node_defs:
                    bad_edges.setdefault(edge, []).append(f'`{node_id}`')

        if bad_edges:
            bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()]
            if len(bad_edges_list) == 1:
                raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.')
            else:
                b = '\n'.join(f' {be}' for be in bad_edges_list)
                raise exceptions.GraphSetupError(
                    f'Nodes are referenced in the graph but not included in the graph:\n{b}'
                )

    def _infer_name(self, function_frame: types.FrameType | None) -> None:
        """Infer the graph name from the call frame.

        The caller's frame is searched for a variable that is bound to this graph, first in its locals and
        then in its globals; the first matching variable name becomes `self.name`. If nothing matches, the
        name is left as `None`.

        Usage should be `self._infer_name(inspect.currentframe())`.

        Copied from `Agent`.
        """
        assert self.name is None, 'Name already set'
        if function_frame is not None and (parent_frame := function_frame.f_back):  # pragma: no branch
            for name, item in parent_frame.f_locals.items():
                if item is self:
                    self.name = name
                    return
            if parent_frame.f_locals != parent_frame.f_globals:
                # if we couldn't find the graph in locals and globals are a different dict, try globals
                for name, item in parent_frame.f_globals.items():
                    if item is self:
                        self.name = name
                        return

587 

588 

class GraphRun(Generic[StateT, DepsT, RunEndT]):
    """A stateful, async-iterable run of a [`Graph`][pydantic_graph.graph.Graph].

    You typically get a `GraphRun` instance from calling
    `async with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate
    through nodes as they run, either by `async for` iteration or by repeatedly calling `.next(...)`.

    Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]:
    ```py {title="iter_never_42.py" noqa="I001" py="3.10"}
    from copy import deepcopy
    from never_42 import Increment, MyState, never_42_graph

    async def main():
        state = MyState(1)
        async with never_42_graph.iter(Increment(), state=state) as graph_run:
            node_states = [(graph_run.next_node, deepcopy(graph_run.state))]
            async for node in graph_run:
                node_states.append((node, deepcopy(graph_run.state)))
            print(node_states)
            '''
            [
                (Increment(), MyState(number=1)),
                (Check42(), MyState(number=2)),
                (End(data=2), MyState(number=2)),
            ]
            '''

        state = MyState(41)
        async with never_42_graph.iter(Increment(), state=state) as graph_run:
            node_states = [(graph_run.next_node, deepcopy(graph_run.state))]
            async for node in graph_run:
                node_states.append((node, deepcopy(graph_run.state)))
            print(node_states)
            '''
            [
                (Increment(), MyState(number=41)),
                (Check42(), MyState(number=42)),
                (Increment(), MyState(number=42)),
                (Check42(), MyState(number=43)),
                (End(data=43), MyState(number=43)),
            ]
            '''
    ```

    See the [`GraphRun.next` documentation][pydantic_graph.graph.GraphRun.next] for an example of how to manually
    drive the graph run.
    """

    def __init__(
        self,
        *,
        graph: Graph[StateT, DepsT, RunEndT],
        start_node: BaseNode[StateT, DepsT, RunEndT],
        persistence: BaseStatePersistence[StateT, RunEndT],
        state: StateT,
        deps: DepsT,
        snapshot_id: str | None = None,
    ):
        """Create a new run for a given graph, starting at the specified node.

        Typically, you'll use [`Graph.iter`][pydantic_graph.graph.Graph.iter] rather than calling this directly.

        Args:
            graph: The [`Graph`][pydantic_graph.graph.Graph] to run.
            start_node: The node where execution will begin.
            persistence: State persistence interface.
            state: A shared state object or primitive (like a counter, dataclass, etc.) that is available
                to all nodes via `ctx.state`.
            deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections,
                configuration, or logging clients.
            snapshot_id: The ID of the snapshot the node came from.
        """
        self.graph = graph
        self.persistence = persistence
        self._snapshot_id: str | None = snapshot_id
        self.state = state
        self.deps = deps

        self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node

    @property
    def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
        """The next node that will be run in the graph.

        This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`.
        """
        return self._next_node

    @property
    def result(self) -> GraphRunResult[StateT, RunEndT] | None:
        """The final result of the graph run if the run is completed, otherwise `None`."""
        if not isinstance(self._next_node, End):
            return None  # The GraphRun has not finished running
        return GraphRunResult(
            self._next_node.data,
            state=self.state,
            persistence=self.persistence,
        )

    async def next(
        self, node: BaseNode[StateT, DepsT, RunEndT] | None = None
    ) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
        """Manually drive the graph run by passing in the node you want to run next.

        This lets you inspect or mutate the node before continuing execution, or skip certain nodes
        under dynamic conditions. The graph run should stop when you return an [`End`][pydantic_graph.nodes.End] node.

        Here's an example of using `next` to drive the graph from [above][pydantic_graph.graph.Graph]:
        ```py {title="next_never_42.py" noqa="I001" py="3.10"}
        from copy import deepcopy
        from pydantic_graph import End
        from never_42 import Increment, MyState, never_42_graph

        async def main():
            state = MyState(48)
            async with never_42_graph.iter(Increment(), state=state) as graph_run:
                next_node = graph_run.next_node  # start with the first node
                node_states = [(next_node, deepcopy(graph_run.state))]

                while not isinstance(next_node, End):
                    if graph_run.state.number == 50:
                        graph_run.state.number = 42
                    next_node = await graph_run.next(next_node)
                    node_states.append((next_node, deepcopy(graph_run.state)))

                print(node_states)
                '''
                [
                    (Increment(), MyState(number=48)),
                    (Check42(), MyState(number=49)),
                    (End(data=49), MyState(number=49)),
                ]
                '''
        ```

        Args:
            node: The node to run next in the graph. If not specified, uses `self.next_node`, which is initialized to
                the `start_node` of the run and updated each time a new node is returned.

        Returns:
            The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if
            the run has completed.

        Raises:
            TypeError: If the node to run is not a `BaseNode` instance, e.g. because `next` was called again
                after the run had already ended.
            GraphRuntimeError: If the node is not part of the graph, or if the node's `run` method returns
                something other than a `BaseNode` or `End`.
        """
        if node is None:
            # This cast is necessary because self._next_node could be an `End`. If that's the case, the check
            # below raises; the only way to get there is to call `next` manually after the run finished.
            node = cast(BaseNode[StateT, DepsT, RunEndT], self._next_node)

        if not isinstance(node, BaseNode):
            # While technically this is not compatible with the documented method signature, it's an easy mistake to
            # make, and we should eagerly provide a more helpful error message than you'd get otherwise.
            # This runs before anything is called on `node` or written to persistence, so a bad argument
            # can neither fail with an unrelated error first nor end up in a snapshot.
            raise TypeError(f'`next` must be called with a `BaseNode` instance, got {node!r}.')

        node_snapshot_id = node.get_snapshot_id()
        if node_snapshot_id != self._snapshot_id:
            # the node did not come from the snapshot this run last recorded, so snapshot it before running it
            await self.persistence.snapshot_node_if_new(node_snapshot_id, self.state, node)
            self._snapshot_id = node_snapshot_id

        node_id = node.get_node_id()
        if node_id not in self.graph.node_defs:
            raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.')

        with ExitStack() as stack:
            if self.graph.auto_instrument:
                stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node))

            async with self.persistence.record_run(node_snapshot_id):
                ctx = GraphRunContext(self.state, self.deps)
                self._next_node = await node.run(ctx)

        if isinstance(self._next_node, End):
            self._snapshot_id = self._next_node.get_snapshot_id()
            await self.persistence.snapshot_end(self.state, self._next_node)
        elif isinstance(self._next_node, BaseNode):
            self._snapshot_id = self._next_node.get_snapshot_id()
            await self.persistence.snapshot_node(self.state, self._next_node)
        else:
            raise exceptions.GraphRuntimeError(
                f'Invalid node return type: `{type(self._next_node).__name__}`. Expected `BaseNode` or `End`.'
            )

        return self._next_node

    def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]]:
        return self

    async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
        """Use the last returned node as the input to `GraphRun.next`; stop once an `End` has been reached."""
        if isinstance(self._next_node, End):
            raise StopAsyncIteration
        return await self.next(self._next_node)

    def __repr__(self) -> str:
        return f'<GraphRun graph={self.graph.name or "[unnamed]"}>'

786 

787 

@dataclass
class GraphRunResult(Generic[StateT, RunEndT]):
    """The final result of running a graph.

    Instances are built by `GraphRun.result` once the run has reached an `End` node.
    """

    # the `data` carried by the `End` node which finished the run
    output: RunEndT
    # the state object of the run
    state: StateT
    # the persistence interface used by the run; left out of the dataclass repr
    persistence: BaseStatePersistence[StateT, RunEndT] = field(repr=False)