Coverage for tests/graph/test_history.py: 96.61%
59 statements
Report generated by coverage.py v7.6.7 at 2025-01-25 16:43 +0000
1# pyright: reportPrivateUsage=false
2from __future__ import annotations as _annotations
4import json
5from dataclasses import dataclass
6from datetime import datetime, timezone
8import pytest
9from dirty_equals import IsStr
10from inline_snapshot import snapshot
12from pydantic_graph import BaseNode, End, EndStep, Graph, GraphRunContext, GraphSetupError, NodeStep
14from ..conftest import IsFloat, IsNow
# Run every test in this module under the anyio pytest plugin (list form of module-level marks).
pytestmark = [pytest.mark.anyio]
@dataclass()
class MyState:
    """Mutable graph state shared by the ``Foo`` -> ``Bar`` test graph."""

    # counter incremented by ``Foo``
    x: int
    # string that ``Bar`` appends ``'y'`` to
    y: str
# Entry node: bumps ``state.x`` by one, then hands control to ``Bar``.
# (Documented with comments rather than a docstring so ``Foo.__doc__`` keeps the
# dataclass-generated value.)
@dataclass
class Foo(BaseNode[MyState]):
    async def run(self, ctx: GraphRunContext[MyState]) -> Bar:
        state = ctx.state
        state.x = state.x + 1
        return Bar()
# Final node: appends ``'y'`` to ``state.y`` and ends the run with ``state.x * 2``.
# (Comments rather than a docstring keep ``Bar.__doc__`` unchanged.)
@dataclass
class Bar(BaseNode[MyState, None, int]):
    async def run(self, ctx: GraphRunContext[MyState]) -> End[int]:
        state = ctx.state
        state.y = state.y + 'y'
        doubled = state.x * 2
        return End(doubled)
@pytest.mark.parametrize(
    'graph',
    [
        Graph(nodes=(Foo, Bar), state_type=MyState, run_end_type=int),
        Graph(nodes=(Foo, Bar), state_type=MyState),
        Graph(nodes=(Foo, Bar), run_end_type=int),
        Graph(nodes=(Foo, Bar)),
    ],
)
async def test_dump_load_history(graph: Graph[MyState, None, int]):
    """Check that a run's history survives a JSON dump/load round trip.

    The same assertions run against four graphs that differ only in whether
    ``state_type`` and/or ``run_end_type`` are passed explicitly.
    """
    # Start at Foo with x=1: Foo makes x=2, then Bar appends 'y' and ends with 2 * 2.
    run_result, run_history = await graph.run(Foo(), state=MyState(1, ''))
    assert run_result == snapshot(4)
    assert run_history == snapshot(
        [
            NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
            NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
            EndStep(result=End(4), ts=IsNow(tz=timezone.utc)),
        ]
    )

    # In the serialized form, nodes appear as ``{'node_id': ...}`` and each step carries a ``kind`` tag.
    dumped = graph.dump_history(run_history)
    assert json.loads(dumped) == snapshot(
        [
            {
                'state': {'x': 2, 'y': ''},
                'node': {'node_id': 'Foo'},
                'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'),
                'duration': IsFloat(),
                'kind': 'node',
            },
            {
                'state': {'x': 2, 'y': 'y'},
                'node': {'node_id': 'Bar'},
                'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'),
                'duration': IsFloat(),
                'kind': 'node',
            },
            {'result': {'data': 4}, 'ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), 'kind': 'end'},
        ]
    )
    # Loading what was just dumped must reproduce the in-memory history.
    assert run_history == graph.load_history(dumped)

    # Hand-written history. Per the expected snapshot below, the string '42' loads as the
    # int 42 and the integer duration 123 loads as the float 123.0.
    handwritten = [
        {
            'state': {'x': 2, 'y': ''},
            'node': {'node_id': 'Foo'},
            'start_ts': '2025-01-01T00:00:00Z',
            'duration': 123,
            'kind': 'node',
        },
        {'result': {'data': '42'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'},
    ]
    reloaded = graph.load_history(json.dumps(handwritten))
    assert reloaded == snapshot(
        [
            NodeStep(
                state=MyState(x=2, y=''),
                node=Foo(),
                start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc),
                duration=123.0,
            ),
            EndStep(result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)),
        ]
    )
def test_one_node():
    """A single-node graph can load a history that consists only of an end step."""

    @dataclass
    class MyNode(BaseNode[None, None, int]):
        async def run(self, ctx: GraphRunContext) -> End[int]:
            return End(123)

    single_node_graph = Graph(nodes=[MyNode])

    # The end payload is written as the string '123'; the snapshot expects it loaded as the int 123.
    raw_history = json.dumps([{'result': {'data': '123'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}])
    assert single_node_graph.load_history(raw_history) == snapshot(
        [
            EndStep(result=End(data=123), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)),
        ]
    )
def test_no_generic_arg():
    """Cover a node declared without generic arguments.

    The state type resolves to ``NoneType``. The run end type cannot be inferred and
    must be supplied via ``run_end_type``.
    """

    @dataclass
    class NoGenericArgsNode(BaseNode):
        async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode:
            return NoGenericArgsNode()

    # Without `run_end_type`: state type is NoneType, end-type lookup raises GraphSetupError.
    untyped_graph = Graph(nodes=[NoGenericArgsNode])
    assert untyped_graph._get_state_type() is type(None)
    with pytest.raises(GraphSetupError, match='Could not infer run end type from nodes, please set `run_end_type`.'):
        untyped_graph._get_run_end_type()

    # Explicit `run_end_type=None` is accepted at runtime; pyright rejects the argument type, hence the ignore.
    explicit_graph = Graph(nodes=[NoGenericArgsNode], run_end_type=None)  # pyright: ignore[reportArgumentType]
    assert explicit_graph._get_run_end_type() is None

    raw_history = json.dumps([{'result': {'data': None}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}])
    assert explicit_graph.load_history(raw_history) == snapshot(
        [
            EndStep(result=End(data=None), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)),
        ]
    )