Coverage for tests/graph/test_graph.py: 99.54%
214 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1# pyright: reportPrivateUsage=false
2from __future__ import annotations as _annotations
4from dataclasses import dataclass
5from datetime import timezone
6from functools import cache
7from typing import Union
9import pytest
10from dirty_equals import IsStr
11from inline_snapshot import snapshot
13from pydantic_graph import (
14 BaseNode,
15 End,
16 EndSnapshot,
17 FullStatePersistence,
18 Graph,
19 GraphRunContext,
20 GraphRuntimeError,
21 GraphSetupError,
22 NodeSnapshot,
23 SimpleStatePersistence,
24)
26from ..conftest import IsFloat, IsNow
# Module-level mark: pytest applies it to every test in this file, so the async
# tests below are run via the anyio pytest plugin.
pytestmark = pytest.mark.anyio
@dataclass
class Float2String(BaseNode):
    # Entry node of the sample pipeline: turns a float into its string form.
    input_data: float

    async def run(self, ctx: GraphRunContext) -> String2Length:
        """Render `input_data` with `str()` and hand the text to `String2Length`."""
        rendered = str(self.input_data)
        return String2Length(rendered)
@dataclass
class String2Length(BaseNode):
    # Middle node of the sample pipeline: turns a string into its character count.
    input_data: str

    async def run(self, ctx: GraphRunContext) -> Double:
        """Measure `input_data` and hand its length to `Double`."""
        length = len(self.input_data)
        return Double(length)
@dataclass
class Double(BaseNode[None, None, int]):
    # Final node of the sample pipeline. It normally ends the run with twice its
    # input; an input of exactly 7 instead loops back to `String2Length`, which
    # gives the graph a cycle.
    input_data: int

    async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]:  # noqa: UP007
        """Return `End(input_data * 2)`, or `String2Length('x' * 21)` when `input_data` is 7."""
        if self.input_data != 7:
            return End(self.input_data * 2)
        return String2Length('x' * 21)
async def test_graph():
    """Run the three-node sample graph end to end and check its output and inferred name."""
    my_graph = Graph(nodes=(Float2String, String2Length, Double))
    # The graph is unnamed before its first run.
    assert my_graph.name is None
    # Presumably (state type, `End` data type), taken from `Double`'s `BaseNode[None, None, int]`.
    assert my_graph.inferred_types == (type(None), int)
    result = await my_graph.run(Float2String(3.14))
    # len('3.14') * 2 == 8
    assert result.output == 8
    # The run names the graph after the caller's local variable holding it, so
    # renaming `my_graph` would break this assertion.
    assert my_graph.name == 'my_graph'
async def test_graph_history(mock_snapshot_id: object):
    """Check the per-node history recorded by `FullStatePersistence` over two runs.

    The second run reaches `Double(input_data=7)`, which loops back through
    `String2Length`, so its history also contains that cycle.
    """
    my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double))
    assert my_graph.name is None
    assert my_graph.inferred_types == (type(None), int)
    sp = FullStatePersistence()
    result = await my_graph.run(Float2String(3.14), persistence=sp)
    # len('3.14') * 2 == 8
    assert result.output == 8
    # Named after the local variable holding the graph.
    assert my_graph.name == 'my_graph'
    assert sp.history == snapshot(
        [
            NodeSnapshot(
                state=None,
                node=Float2String(input_data=3.14),
                start_ts=IsNow(tz=timezone.utc),
                status='success',
                id='Float2String:1',
                duration=IsFloat(),
            ),
            NodeSnapshot(
                state=None,
                node=String2Length(input_data='3.14'),
                start_ts=IsNow(tz=timezone.utc),
                status='success',
                id='String2Length:2',
                duration=IsFloat(),
            ),
            NodeSnapshot(
                state=None,
                node=Double(input_data=4),
                start_ts=IsNow(tz=timezone.utc),
                status='success',
                id='Double:3',
                duration=IsFloat(),
            ),
            EndSnapshot(state=None, result=End(data=8), ts=IsNow(tz=timezone.utc), id='end:4'),
        ]
    )
    # Fresh persistence for the second run. The snapshot ids keep counting from 5,
    # presumably because the `mock_snapshot_id` fixture's counter is shared across runs.
    sp = FullStatePersistence()
    result = await my_graph.run(Float2String(3.14159), persistence=sp)
    # len('3.14159') == 7 takes the loop branch: len('x' * 21) == 21, and 21 * 2 == 42
    assert result.output == 42
    assert sp.history == snapshot(
        [
            NodeSnapshot(
                state=None,
                node=Float2String(input_data=3.14159),
                start_ts=IsNow(tz=timezone.utc),
                status='success',
                id='Float2String:5',
                duration=IsFloat(),
            ),
            NodeSnapshot(
                state=None,
                node=String2Length(input_data='3.14159'),
                start_ts=IsNow(tz=timezone.utc),
                status='success',
                id='String2Length:6',
                duration=IsFloat(),
            ),
            NodeSnapshot(
                state=None,
                node=Double(input_data=7),
                start_ts=IsNow(tz=timezone.utc),
                status='success',
                id='Double:7',
                duration=IsFloat(),
            ),
            NodeSnapshot(
                state=None,
                node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'),
                start_ts=IsNow(tz=timezone.utc),
                status='success',
                id='String2Length:8',
                duration=IsFloat(),
            ),
            NodeSnapshot(
                state=None,
                node=Double(input_data=21),
                start_ts=IsNow(tz=timezone.utc),
                status='success',
                id='Double:9',
                duration=IsFloat(),
            ),
            EndSnapshot(state=None, result=End(data=42), ts=IsNow(tz=timezone.utc), id='end:10'),
        ]
    )
    # `.node` on every history entry; the trailing `EndSnapshot` contributes its `End(data=42)`.
    assert [e.node for e in sp.history] == snapshot(
        [
            Float2String(input_data=3.14159),
            String2Length(input_data='3.14159'),
            Double(input_data=7),
            String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'),
            Double(input_data=21),
            End(data=42),
        ]
    )
def test_one_bad_node():
    """A node named in a return annotation but missing from `nodes` is a setup error."""

    # Shadows the module-level node; its `run` points at the local `String2Length` below.
    class Float2String(BaseNode):
        async def run(self, ctx: GraphRunContext) -> String2Length:
            raise NotImplementedError()

    # Defined so the annotation above resolves, but deliberately left out of the graph.
    class String2Length(BaseNode[None, None, None]):  # pyright: ignore[reportUnusedClass]
        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    with pytest.raises(GraphSetupError) as raised:
        Graph(nodes=(Float2String,))

    error = raised.value
    assert error.message == snapshot(
        '`String2Length` is referenced by `Float2String` but not included in the graph.'
    )
def test_two_bad_nodes():
    """Several missing nodes are reported together, one per line."""

    # The only node passed to the graph; its annotation references two unregistered nodes.
    class Foo(BaseNode):
        input_data: float

        async def run(self, ctx: GraphRunContext) -> Union[Bar, Spam]:  # noqa: UP007
            raise NotImplementedError()

    class Bar(BaseNode[None, None, None]):
        input_data: str

        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    class Spam(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    with pytest.raises(GraphSetupError) as raised:
        Graph(nodes=(Foo,))

    error = raised.value
    assert error.message == snapshot("""\
Nodes are referenced in the graph but not included in the graph:
 `Bar` is referenced by `Foo`
 `Spam` is referenced by `Foo`\
""")
def test_three_bad_nodes_separate():
    """One missing node referenced from several graph nodes is reported once, listing every referrer."""

    class Foo(BaseNode):
        input_data: float

        async def run(self, ctx: GraphRunContext) -> Eggs:
            raise NotImplementedError()

    class Bar(BaseNode[None, None, None]):
        input_data: str

        async def run(self, ctx: GraphRunContext) -> Eggs:
            raise NotImplementedError()

    class Spam(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> Eggs:
            raise NotImplementedError()

    # Referenced by all three nodes above but not passed to the graph.
    class Eggs(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    with pytest.raises(GraphSetupError) as raised:
        Graph(nodes=(Foo, Bar, Spam))

    error = raised.value
    assert error.message == snapshot(
        '`Eggs` is referenced by `Foo`, `Bar`, and `Spam` but not included in the graph.'
    )
def test_duplicate_id():
    """Two node classes reporting the same node ID are rejected when the graph is built."""

    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            raise NotImplementedError()

    class Bar(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

        # Override the ID so that it collides with the one `Foo` reports.
        @classmethod
        @cache
        def get_node_id(cls) -> str:
            return 'Foo'

    with pytest.raises(GraphSetupError) as raised:
        Graph(nodes=(Foo, Bar))

    error = raised.value
    assert error.message == snapshot(IsStr(regex='Node ID `Foo` is not unique — found on <class.+'))
async def test_run_node_not_in_graph():
    """Returning an instance of a node class that was never registered fails at run time."""

    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            # Deliberately ignores its own annotation and returns an unregistered node.
            return Spam()  # type: ignore

    @dataclass
    class Spam(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    g = Graph(nodes=(Foo, Bar))
    with pytest.raises(GraphRuntimeError) as raised:
        await g.run(Foo())

    error = raised.value
    assert error.message == snapshot('Node `test_run_node_not_in_graph.<locals>.Spam()` is not in the graph.')
async def test_run_return_other(mock_snapshot_id: object):
    """A node returning something that is neither a `BaseNode` nor an `End` is a runtime error."""

    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            # A plain int, which the graph must reject.
            return 42  # type: ignore

    g = Graph(nodes=(Foo, Bar))
    expected_types = (type(None), type(None))
    assert g.inferred_types == expected_types
    with pytest.raises(GraphRuntimeError) as raised:
        await g.run(Foo())

    error = raised.value
    assert error.message == snapshot('Invalid node return type: `int`. Expected `BaseNode` or `End`.')
async def test_iter():
    """Iterating a `GraphRun` yields every node after the start node, finishing with the `End`."""
    my_graph = Graph(nodes=(Float2String, String2Length, Double))
    assert my_graph.name is None
    assert my_graph.inferred_types == (type(None), int)
    node_reprs: list[str] = []
    async with my_graph.iter(Float2String(3.14)) as graph_iter:
        # `iter()` names the graph after the caller's variable, so `my_graph` must keep its name.
        assert repr(graph_iter) == snapshot('<GraphRun graph=my_graph>')
        node_reprs += [repr(node) async for node in graph_iter]
        # len('3.14') * 2 == 8
        assert graph_iter.result
        assert graph_iter.result.output == 8

    # The start node (`Float2String`) itself is not among the yielded nodes.
    assert node_reprs == snapshot(["String2Length(input_data='3.14')", 'Double(input_data=4)', 'End(data=8)'])
async def test_iter_next(mock_snapshot_id: object):
    """Step through a `Foo` <-> `Bar` cycle one node at a time with `GraphRun.next()`.

    The two nodes return each other forever (there is no `End`), so the run only
    advances when `next()` is called explicitly.
    """

    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Foo:
            return Foo()

    g = Graph(nodes=(Foo, Bar))
    assert g.name is None
    sp = FullStatePersistence()
    async with g.iter(Foo(), persistence=sp) as run:
        # `iter()` names the graph after the caller's variable holding it.
        assert g.name == 'g'
        n = await run.next()
        assert n == Bar()
        # After one step, `Foo` is recorded with timing and 'success'. The returned `Bar`
        # is recorded without start time, duration or status.
        assert sp.history == snapshot(
            [
                NodeSnapshot(
                    state=None,
                    node=Foo(),
                    start_ts=IsNow(tz=timezone.utc),
                    duration=IsFloat(),
                    status='success',
                    id='Foo:1',
                ),
                NodeSnapshot(state=None, node=Bar(), id='Bar:2'),
            ]
        )

        assert isinstance(n, Bar)
        n2 = await run.next()
        assert n2 == Foo()

        # After the second step `Bar` has completed too, and a new `Foo` snapshot is appended.
        assert sp.history == snapshot(
            [
                NodeSnapshot(
                    state=None,
                    node=Foo(),
                    start_ts=IsNow(tz=timezone.utc),
                    duration=IsFloat(),
                    status='success',
                    id='Foo:1',
                ),
                NodeSnapshot(
                    state=None,
                    node=Bar(),
                    start_ts=IsNow(tz=timezone.utc),
                    duration=IsFloat(),
                    status='success',
                    id='Bar:2',
                ),
                NodeSnapshot(state=None, node=Foo(), id='Foo:3'),
            ]
        )
async def test_iter_next_error(mock_snapshot_id: object):
    """Calling `GraphRun.next()` again after the run has reached `End` raises `TypeError`."""

    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            return End(None)

    g = Graph(nodes=(Foo, Bar))
    sp = SimpleStatePersistence()
    async with g.iter(Foo(), persistence=sp) as run:
        n = await run.next()
        assert n == snapshot(Bar())

        assert isinstance(n, BaseNode)
        n = await run.next()
        assert n == snapshot(End(data=None))

        # The run has already produced an `End`, so another step is rejected.
        # `match` is a regex searched against the message, so the trailing period is
        # escaped; unescaped, it would match any character.
        with pytest.raises(
            TypeError, match=r'`next` must be called with a `BaseNode` instance, got End\(data=None\)\.'
        ):
            await run.next()
async def test_next(mock_snapshot_id: object):
    """The deprecated `Graph.next()` still runs a single node and emits a `DeprecationWarning`."""

    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Foo:
            return Foo()

    g = Graph(nodes=(Foo, Bar))
    assert g.name is None
    sp = FullStatePersistence()
    # `match` is a regex searched against the warning text. `.`, `(` and `)` are escaped so
    # that `graph.iter(...)` is matched literally, not as "any char" plus a capture group.
    with pytest.warns(DeprecationWarning, match=r'`next` is deprecated, use `async with graph\.iter\(\.\.\.\)'):
        n = await g.next(Foo(), persistence=sp)  # pyright: ignore[reportDeprecated]
    assert n == Bar()
    # The test expects `next()` to name the graph after the caller's variable holding it.
    assert g.name == 'g'
    # Only one node ran: `Foo` is recorded with timing and 'success'; the returned `Bar`
    # is recorded without start time, duration or status.
    assert sp.history == snapshot(
        [
            NodeSnapshot(
                state=None,
                node=Foo(),
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
                status='success',
                id='Foo:1',
            ),
            NodeSnapshot(state=None, node=Bar(), id='Bar:2'),
        ]
    )
async def test_deps(mock_snapshot_id: object):
    """Dependencies passed as `Graph.run(deps=...)` are visible to every node as `ctx.deps`."""

    @dataclass
    class Deps:
        a: int
        b: int

    @dataclass
    class Foo(BaseNode[None, Deps]):
        async def run(self, ctx: GraphRunContext[None, Deps]) -> Bar:
            assert isinstance(ctx.deps, Deps)
            return Bar()

    @dataclass
    class Bar(BaseNode[None, Deps, int]):
        async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]:
            assert isinstance(ctx.deps, Deps)
            return End(123)

    g = Graph(nodes=(Foo, Bar))
    persistence = FullStatePersistence()
    run_result = await g.run(Foo(), deps=Deps(a=1, b=2), persistence=persistence)

    # Both nodes completed successfully, and the run ended with Bar's `End(123)`.
    assert run_result.output == 123
    assert persistence.history == snapshot(
        [
            NodeSnapshot(
                state=None,
                node=Foo(),
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
                status='success',
                id='Foo:1',
            ),
            NodeSnapshot(
                state=None,
                node=Bar(),
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
                status='success',
                id='Bar:2',
            ),
            EndSnapshot(state=None, result=End(data=123), ts=IsNow(tz=timezone.utc), id='end:3'),
        ]
    )