Coverage for pydantic_graph/pydantic_graph/graph.py: 97.78%

155 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3import asyncio 

4import inspect 

5import types 

6from collections.abc import Sequence 

7from dataclasses import dataclass, field 

8from functools import cached_property 

9from pathlib import Path 

10from time import perf_counter 

11from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, TypeVar 

12 

13import logfire_api 

14import pydantic 

15import typing_extensions 

16 

17from . import _utils, exceptions, mermaid 

18from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT 

19from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var 

20 

# Public API of this module.
__all__ = ('Graph',)

# Logfire instance scoped to this package; `Graph.run` and `Graph.next` open their spans on it.
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')

# Used in `self: Graph[StateT, DepsT, T]` annotations so that a method's return type tracks the
# run-end type of the nodes it is called with.
T = TypeVar('T')
"""An invariant typevar."""

27 

28 

29@dataclass(init=False) 

30class Graph(Generic[StateT, DepsT, RunEndT]): 

31 """Definition of a graph. 

32 

33 In `pydantic-graph`, a graph is a collection of nodes that can be run in sequence. The nodes define 

34 their outgoing edges — e.g. which nodes may be run next, and thereby the structure of the graph. 

35 

36 Here's a very simple example of a graph which increments a number by 1, but makes sure the number is never 

37 42 at the end. 

38 

39 ```py {title="never_42.py" noqa="I001" py="3.10"} 

40 from __future__ import annotations 

41 

42 from dataclasses import dataclass 

43 

44 from pydantic_graph import BaseNode, End, Graph, GraphRunContext 

45 

46 @dataclass 

47 class MyState: 

48 number: int 

49 

50 @dataclass 

51 class Increment(BaseNode[MyState]): 

52 async def run(self, ctx: GraphRunContext) -> Check42: 

53 ctx.state.number += 1 

54 return Check42() 

55 

56 @dataclass 

57 class Check42(BaseNode[MyState, None, int]): 

58 async def run(self, ctx: GraphRunContext) -> Increment | End[int]: 

59 if ctx.state.number == 42: 

60 return Increment() 

61 else: 

62 return End(ctx.state.number) 

63 

64 never_42_graph = Graph(nodes=(Increment, Check42)) 

65 ``` 

66 _(This example is complete, it can be run "as is")_ 

67 

68 See [`run`][pydantic_graph.graph.Graph.run] For an example of running graph, and 

69 [`mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] for an example of generating a mermaid diagram 

70 from the graph. 

71 """ 

72 

73 name: str | None 

74 node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] 

75 snapshot_state: Callable[[StateT], StateT] 

76 _state_type: type[StateT] | _utils.Unset = field(repr=False) 

77 _run_end_type: type[RunEndT] | _utils.Unset = field(repr=False) 

78 

79 def __init__( 

80 self, 

81 *, 

82 nodes: Sequence[type[BaseNode[StateT, DepsT, RunEndT]]], 

83 name: str | None = None, 

84 state_type: type[StateT] | _utils.Unset = _utils.UNSET, 

85 run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET, 

86 snapshot_state: Callable[[StateT], StateT] = deep_copy_state, 

87 ): 

88 """Create a graph from a sequence of nodes. 

89 

90 Args: 

91 nodes: The nodes which make up the graph, nodes need to be unique and all be generic in the same 

92 state type. 

93 name: Optional name for the graph, if not provided the name will be inferred from the calling frame 

94 on the first call to a graph method. 

95 state_type: The type of the state for the graph, this can generally be inferred from `nodes`. 

96 run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`. 

97 snapshot_state: A function to snapshot the state of the graph, this is used in 

98 [`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record 

99 the state before each step. 

100 """ 

101 self.name = name 

102 self._state_type = state_type 

103 self._run_end_type = run_end_type 

104 self.snapshot_state = snapshot_state 

105 

106 parent_namespace = _utils.get_parent_namespace(inspect.currentframe()) 

107 self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {} 

108 for node in nodes: 

109 self._register_node(node, parent_namespace) 

110 

111 self._validate_edges() 

112 

113 async def run( 

114 self: Graph[StateT, DepsT, T], 

115 start_node: BaseNode[StateT, DepsT, T], 

116 *, 

117 state: StateT = None, 

118 deps: DepsT = None, 

119 infer_name: bool = True, 

120 ) -> tuple[T, list[HistoryStep[StateT, T]]]: 

121 """Run the graph from a starting node until it ends. 

122 

123 Args: 

124 start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, 

125 you need to provide the starting node. 

126 state: The initial state of the graph. 

127 deps: The dependencies of the graph. 

128 infer_name: Whether to infer the graph name from the calling frame. 

129 

130 Returns: 

131 The result type from ending the run and the history of the run. 

132 

133 Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: 

134 

135 ```py {title="run_never_42.py" noqa="I001" py="3.10"} 

136 from never_42 import Increment, MyState, never_42_graph 

137 

138 async def main(): 

139 state = MyState(1) 

140 _, history = await never_42_graph.run(Increment(), state=state) 

141 print(state) 

142 #> MyState(number=2) 

143 print(len(history)) 

144 #> 3 

145 

146 state = MyState(41) 

147 _, history = await never_42_graph.run(Increment(), state=state) 

148 print(state) 

149 #> MyState(number=43) 

150 print(len(history)) 

151 #> 5 

152 ``` 

153 """ 

154 if infer_name and self.name is None: 

155 self._infer_name(inspect.currentframe()) 

156 

157 history: list[HistoryStep[StateT, T]] = [] 

158 with _logfire.span( 

159 '{graph_name} run {start=}', 

160 graph_name=self.name or 'graph', 

161 start=start_node, 

162 ) as run_span: 

163 while True: 

164 next_node = await self.next(start_node, history, state=state, deps=deps, infer_name=False) 

165 if isinstance(next_node, End): 

166 history.append(EndStep(result=next_node)) 

167 run_span.set_attribute('history', history) 

168 return next_node.data, history 

169 elif isinstance(next_node, BaseNode): 

170 start_node = next_node 

171 else: 

172 if TYPE_CHECKING: 

173 typing_extensions.assert_never(next_node) 

174 else: 

175 raise exceptions.GraphRuntimeError( 

176 f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' 

177 ) 

178 

179 def run_sync( 

180 self: Graph[StateT, DepsT, T], 

181 start_node: BaseNode[StateT, DepsT, T], 

182 *, 

183 state: StateT = None, 

184 deps: DepsT = None, 

185 infer_name: bool = True, 

186 ) -> tuple[T, list[HistoryStep[StateT, T]]]: 

187 """Run the graph synchronously. 

188 

189 This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. 

190 You therefore can't use this method inside async code or if there's an active event loop. 

191 

192 Args: 

193 start_node: the first node to run, since the graph definition doesn't define the entry point in the graph, 

194 you need to provide the starting node. 

195 state: The initial state of the graph. 

196 deps: The dependencies of the graph. 

197 infer_name: Whether to infer the graph name from the calling frame. 

198 

199 Returns: 

200 The result type from ending the run and the history of the run. 

201 """ 

202 if infer_name and self.name is None: 202 ↛ 204line 202 didn't jump to line 204 because the condition on line 202 was always true

203 self._infer_name(inspect.currentframe()) 

204 return asyncio.get_event_loop().run_until_complete( 

205 self.run(start_node, state=state, deps=deps, infer_name=False) 

206 ) 

207 

208 async def next( 

209 self: Graph[StateT, DepsT, T], 

210 node: BaseNode[StateT, DepsT, T], 

211 history: list[HistoryStep[StateT, T]], 

212 *, 

213 state: StateT = None, 

214 deps: DepsT = None, 

215 infer_name: bool = True, 

216 ) -> BaseNode[StateT, DepsT, Any] | End[T]: 

217 """Run a node in the graph and return the next node to run. 

218 

219 Args: 

220 node: The node to run. 

221 history: The history of the graph run so far. NOTE: this will be mutated to add the new step. 

222 state: The current state of the graph. 

223 deps: The dependencies of the graph. 

224 infer_name: Whether to infer the graph name from the calling frame. 

225 

226 Returns: 

227 The next node to run or [`End`][pydantic_graph.nodes.End] if the graph has finished. 

228 """ 

229 if infer_name and self.name is None: 

230 self._infer_name(inspect.currentframe()) 

231 node_id = node.get_id() 

232 if node_id not in self.node_defs: 

233 raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') 

234 

235 ctx = GraphRunContext(state, deps) 

236 with _logfire.span('run node {node_id}', node_id=node_id, node=node): 

237 start_ts = _utils.now_utc() 

238 start = perf_counter() 

239 next_node = await node.run(ctx) 

240 duration = perf_counter() - start 

241 

242 history.append( 

243 NodeStep(state=state, node=node, start_ts=start_ts, duration=duration, snapshot_state=self.snapshot_state) 

244 ) 

245 return next_node 

246 

247 def dump_history( 

248 self: Graph[StateT, DepsT, T], history: list[HistoryStep[StateT, T]], *, indent: int | None = None 

249 ) -> bytes: 

250 """Dump the history of a graph run as JSON. 

251 

252 Args: 

253 history: The history of the graph run. 

254 indent: The number of spaces to indent the JSON. 

255 

256 Returns: 

257 The JSON representation of the history. 

258 """ 

259 return self.history_type_adapter.dump_json(history, indent=indent) 

260 

261 def load_history(self, json_bytes: str | bytes | bytearray) -> list[HistoryStep[StateT, RunEndT]]: 

262 """Load the history of a graph run from JSON. 

263 

264 Args: 

265 json_bytes: The JSON representation of the history. 

266 

267 Returns: 

268 The history of the graph run. 

269 """ 

270 return self.history_type_adapter.validate_json(json_bytes) 

271 

272 @cached_property 

273 def history_type_adapter(self) -> pydantic.TypeAdapter[list[HistoryStep[StateT, RunEndT]]]: 

274 nodes = [node_def.node for node_def in self.node_defs.values()] 

275 state_t = self._get_state_type() 

276 end_t = self._get_run_end_type() 

277 token = nodes_schema_var.set(nodes) 

278 try: 

279 ta = pydantic.TypeAdapter(list[Annotated[HistoryStep[state_t, end_t], pydantic.Discriminator('kind')]]) 

280 finally: 

281 nodes_schema_var.reset(token) 

282 return ta 

283 

284 def mermaid_code( 

285 self, 

286 *, 

287 start_node: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, 

288 title: str | None | typing_extensions.Literal[False] = None, 

289 edge_labels: bool = True, 

290 notes: bool = True, 

291 highlighted_nodes: Sequence[mermaid.NodeIdent] | mermaid.NodeIdent | None = None, 

292 highlight_css: str = mermaid.DEFAULT_HIGHLIGHT_CSS, 

293 infer_name: bool = True, 

294 direction: mermaid.StateDiagramDirection | None = None, 

295 ) -> str: 

296 """Generate a diagram representing the graph as [mermaid](https://mermaid.js.org/) diagram. 

297 

298 This method calls [`pydantic_graph.mermaid.generate_code`][pydantic_graph.mermaid.generate_code]. 

299 

300 Args: 

301 start_node: The node or nodes which can start the graph. 

302 title: The title of the diagram, use `False` to not include a title. 

303 edge_labels: Whether to include edge labels. 

304 notes: Whether to include notes on each node. 

305 highlighted_nodes: Optional node or nodes to highlight. 

306 highlight_css: The CSS to use for highlighting nodes. 

307 infer_name: Whether to infer the graph name from the calling frame. 

308 direction: The direction of flow. 

309 

310 Returns: 

311 The mermaid code for the graph, which can then be rendered as a diagram. 

312 

313 Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]: 

314 

315 ```py {title="never_42.py" py="3.10"} 

316 from never_42 import Increment, never_42_graph 

317 

318 print(never_42_graph.mermaid_code(start_node=Increment)) 

319 ''' 

320 --- 

321 title: never_42_graph 

322 --- 

323 stateDiagram-v2 

324 [*] --> Increment 

325 Increment --> Check42 

326 Check42 --> Increment 

327 Check42 --> [*] 

328 ''' 

329 ``` 

330 

331 The rendered diagram will look like this: 

332 

333 ```mermaid 

334 --- 

335 title: never_42_graph 

336 --- 

337 stateDiagram-v2 

338 [*] --> Increment 

339 Increment --> Check42 

340 Check42 --> Increment 

341 Check42 --> [*] 

342 ``` 

343 """ 

344 if infer_name and self.name is None: 

345 self._infer_name(inspect.currentframe()) 

346 if title is None and self.name: 

347 title = self.name 

348 return mermaid.generate_code( 

349 self, 

350 start_node=start_node, 

351 highlighted_nodes=highlighted_nodes, 

352 highlight_css=highlight_css, 

353 title=title or None, 

354 edge_labels=edge_labels, 

355 notes=notes, 

356 direction=direction, 

357 ) 

358 

359 def mermaid_image( 

360 self, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig] 

361 ) -> bytes: 

362 """Generate a diagram representing the graph as an image. 

363 

364 The format and diagram can be customized using `kwargs`, 

365 see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. 

366 

367 !!! note "Uses external service" 

368 This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` 

369 is a free service not affiliated with Pydantic. 

370 

371 Args: 

372 infer_name: Whether to infer the graph name from the calling frame. 

373 **kwargs: Additional arguments to pass to `mermaid.request_image`. 

374 

375 Returns: 

376 The image bytes. 

377 """ 

378 if infer_name and self.name is None: 

379 self._infer_name(inspect.currentframe()) 

380 if 'title' not in kwargs and self.name: 

381 kwargs['title'] = self.name 

382 return mermaid.request_image(self, **kwargs) 

383 

384 def mermaid_save( 

385 self, path: Path | str, /, *, infer_name: bool = True, **kwargs: typing_extensions.Unpack[mermaid.MermaidConfig] 

386 ) -> None: 

387 """Generate a diagram representing the graph and save it as an image. 

388 

389 The format and diagram can be customized using `kwargs`, 

390 see [`pydantic_graph.mermaid.MermaidConfig`][pydantic_graph.mermaid.MermaidConfig]. 

391 

392 !!! note "Uses external service" 

393 This method makes a request to [mermaid.ink](https://mermaid.ink) to render the image, `mermaid.ink` 

394 is a free service not affiliated with Pydantic. 

395 

396 Args: 

397 path: The path to save the image to. 

398 infer_name: Whether to infer the graph name from the calling frame. 

399 **kwargs: Additional arguments to pass to `mermaid.save_image`. 

400 """ 

401 if infer_name and self.name is None: 

402 self._infer_name(inspect.currentframe()) 

403 if 'title' not in kwargs and self.name: 

404 kwargs['title'] = self.name 

405 mermaid.save_image(path, self, **kwargs) 

406 

407 def _get_state_type(self) -> type[StateT]: 

408 if _utils.is_set(self._state_type): 

409 return self._state_type 

410 

411 for node_def in self.node_defs.values(): 

412 for base in typing_extensions.get_original_bases(node_def.node): 

413 if typing_extensions.get_origin(base) is BaseNode: 

414 args = typing_extensions.get_args(base) 

415 if args: 415 ↛ 418line 415 didn't jump to line 418 because the condition on line 415 was always true

416 return args[0] 

417 # break the inner (bases) loop 

418 break 

419 # state defaults to None, so use that if we can't infer it 

420 return type(None) # pyright: ignore[reportReturnType] 

421 

422 def _get_run_end_type(self) -> type[RunEndT]: 

423 if _utils.is_set(self._run_end_type): 

424 return self._run_end_type 

425 

426 for node_def in self.node_defs.values(): 

427 for base in typing_extensions.get_original_bases(node_def.node): 

428 if typing_extensions.get_origin(base) is BaseNode: 

429 args = typing_extensions.get_args(base) 

430 if len(args) == 3: 

431 t = args[2] 

432 if not _utils.is_never(t): 

433 return t 

434 # break the inner (bases) loop 

435 break 

436 raise exceptions.GraphSetupError('Could not infer run end type from nodes, please set `run_end_type`.') 

437 

438 def _register_node( 

439 self: Graph[StateT, DepsT, T], 

440 node: type[BaseNode[StateT, DepsT, T]], 

441 parent_namespace: dict[str, Any] | None, 

442 ) -> None: 

443 node_id = node.get_id() 

444 if existing_node := self.node_defs.get(node_id): 

445 raise exceptions.GraphSetupError( 

446 f'Node ID `{node_id}` is not unique — found on {existing_node.node} and {node}' 

447 ) 

448 else: 

449 self.node_defs[node_id] = node.get_node_def(parent_namespace) 

450 

451 def _validate_edges(self): 

452 known_node_ids = self.node_defs.keys() 

453 bad_edges: dict[str, list[str]] = {} 

454 

455 for node_id, node_def in self.node_defs.items(): 

456 for edge in node_def.next_node_edges.keys(): 

457 if edge not in known_node_ids: 

458 bad_edges.setdefault(edge, []).append(f'`{node_id}`') 

459 

460 if bad_edges: 

461 bad_edges_list = [f'`{k}` is referenced by {_utils.comma_and(v)}' for k, v in bad_edges.items()] 

462 if len(bad_edges_list) == 1: 

463 raise exceptions.GraphSetupError(f'{bad_edges_list[0]} but not included in the graph.') 

464 else: 

465 b = '\n'.join(f' {be}' for be in bad_edges_list) 

466 raise exceptions.GraphSetupError( 

467 f'Nodes are referenced in the graph but not included in the graph:\n{b}' 

468 ) 

469 

470 def _infer_name(self, function_frame: types.FrameType | None) -> None: 

471 """Infer the agent name from the call frame. 

472 

473 Usage should be `self._infer_name(inspect.currentframe())`. 

474 

475 Copied from `Agent`. 

476 """ 

477 assert self.name is None, 'Name already set' 

478 if function_frame is not None and (parent_frame := function_frame.f_back): # pragma: no branch 

479 for name, item in parent_frame.f_locals.items(): 

480 if item is self: 

481 self.name = name 

482 return 

483 if parent_frame.f_locals != parent_frame.f_globals: 483 ↛ exitline 483 didn't return from function '_infer_name' because the condition on line 483 was always true

484 # if we couldn't find the agent in locals and globals are a different dict, try globals 

485 for name, item in parent_frame.f_globals.items(): 485 ↛ exitline 485 didn't return from function '_infer_name' because the loop on line 485 didn't complete

486 if item is self: 

487 self.name = name 

488 return