Coverage for tests/graph/test_history.py: 96.61%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1# pyright: reportPrivateUsage=false 

2from __future__ import annotations as _annotations 

3 

4import json 

5from dataclasses import dataclass 

6from datetime import datetime, timezone 

7 

8import pytest 

9from dirty_equals import IsStr 

10from inline_snapshot import snapshot 

11 

12from pydantic_graph import BaseNode, End, EndStep, Graph, GraphRunContext, GraphSetupError, NodeStep 

13 

14from ..conftest import IsFloat, IsNow 

15 

# Module-level pytest marker: applies ``pytest.mark.anyio`` to every test in this file
# (presumably so the anyio pytest plugin can drive the ``async def`` tests here).
pytestmark = pytest.mark.anyio

17 

18 

@dataclass()
class MyState:
    """Mutable state carried through the ``Foo`` -> ``Bar`` graph used by these tests."""

    x: int  # incremented once by ``Foo``; ``Bar`` ends the run with twice this value
    y: str  # ``Bar`` appends a single ``'y'``

23 

24 

@dataclass
class Foo(BaseNode[MyState]):
    """Entry node: bump ``state.x`` by one, then hand control to ``Bar``."""

    async def run(self, ctx: GraphRunContext[MyState]) -> Bar:
        state = ctx.state
        state.x = state.x + 1
        # The ``-> Bar`` annotation above is left as-is: the graph presumably derives
        # this node's outgoing edge from it.
        return Bar()

30 

31 

@dataclass
class Bar(BaseNode[MyState, None, int]):
    """Terminal node: append ``'y'`` to ``state.y`` and end the run with ``2 * state.x``."""

    async def run(self, ctx: GraphRunContext[MyState]) -> End[int]:
        state = ctx.state
        state.y = state.y + 'y'
        doubled = state.x * 2
        return End(doubled)

37 

38 

@pytest.mark.parametrize(
    'graph',
    [
        # state_type / run_end_type: both given, state only, run-end only, neither
        Graph(nodes=(Foo, Bar), state_type=MyState, run_end_type=int),
        Graph(nodes=(Foo, Bar), state_type=MyState),
        Graph(nodes=(Foo, Bar), run_end_type=int),
        Graph(nodes=(Foo, Bar)),
    ],
)
async def test_dump_load_history(graph: Graph[MyState, None, int]):
    """Run ``Foo -> Bar`` and check that ``dump_history`` / ``load_history`` round-trip.

    A hand-written JSON history is then loaded to check that its fields are
    validated into ``NodeStep`` / ``EndStep`` objects: the ``'Z'`` timestamps
    become UTC-aware datetimes, ``123`` becomes ``123.0`` and the string ``'42'``
    becomes the int ``42``.
    """
    # Foo bumps x from 1 to 2, then Bar ends the run with 2 * 2.
    run_result, run_history = await graph.run(Foo(), state=MyState(1, ''))
    assert run_result == snapshot(4)
    assert run_history == snapshot(
        [
            NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
            NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
            EndStep(result=End(4), ts=IsNow(tz=timezone.utc)),
        ]
    )

    dumped = graph.dump_history(run_history)
    assert json.loads(dumped) == snapshot(
        [
            {
                'state': {'x': 2, 'y': ''},
                'node': {'node_id': 'Foo'},
                'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'),
                'duration': IsFloat(),
                'kind': 'node',
            },
            {
                'state': {'x': 2, 'y': 'y'},
                'node': {'node_id': 'Bar'},
                'start_ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'),
                'duration': IsFloat(),
                'kind': 'node',
            },
            {'result': {'data': 4}, 'ts': IsStr(regex=r'20\d\d-\d\d-\d\dT.+'), 'kind': 'end'},
        ]
    )

    # Loading exactly what was just dumped must reproduce the in-memory history.
    round_tripped = graph.load_history(dumped)
    assert run_history == round_tripped

    hand_written = json.dumps(
        [
            {
                'state': {'x': 2, 'y': ''},
                'node': {'node_id': 'Foo'},
                'start_ts': '2025-01-01T00:00:00Z',
                'duration': 123,
                'kind': 'node',
            },
            {'result': {'data': '42'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'},
        ]
    )
    assert graph.load_history(hand_written) == snapshot(
        [
            NodeStep(
                state=MyState(x=2, y=''),
                node=Foo(),
                start_ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc),
                duration=123.0,
            ),
            EndStep(result=End(data=42), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)),
        ]
    )

103 

104 

def test_one_node():
    """A graph with a single, immediately-ending node loads an ``EndStep``-only history.

    The JSON string ``'123'`` comes back as the int ``123`` (presumably coerced to the
    ``int`` run-end type taken from ``MyNode``'s ``BaseNode`` parameters).
    """

    @dataclass
    class MyNode(BaseNode[None, None, int]):
        async def run(self, ctx: GraphRunContext) -> End[int]:
            return End(123)

    # NOTE(review): built directly in this frame because the local node class is
    # presumably resolved from this function's namespace — keep it here.
    single_node_graph = Graph(nodes=(MyNode,))

    raw_history = json.dumps([{'result': {'data': '123'}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}])
    assert single_node_graph.load_history(raw_history) == snapshot(
        [
            EndStep(result=End(data=123), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)),
        ]
    )

122 

123 

def test_no_generic_arg():
    """Type inference for a ``BaseNode`` subclass declared without generic parameters.

    The state type falls back to ``NoneType``. The run-end type cannot be inferred
    and raises ``GraphSetupError`` until ``run_end_type`` is given explicitly. After
    that, an end-only history with ``None`` data loads.
    """

    @dataclass
    class NoGenericArgsNode(BaseNode):
        async def run(self, ctx: GraphRunContext) -> NoGenericArgsNode:
            return NoGenericArgsNode()

    # NOTE(review): with ``from __future__ import annotations`` the ``-> NoGenericArgsNode``
    # annotation is a string naming a local class, so the Graphs are built in this frame.
    inferred = Graph(nodes=[NoGenericArgsNode])
    assert inferred._get_state_type() is type(None)
    with pytest.raises(GraphSetupError, match='Could not infer run end type from nodes, please set `run_end_type`.'):
        inferred._get_run_end_type()

    explicit = Graph(nodes=[NoGenericArgsNode], run_end_type=None)  # pyright: ignore[reportArgumentType]
    assert explicit._get_run_end_type() is None

    raw_history = json.dumps([{'result': {'data': None}, 'ts': '2025-01-01T00:00:00Z', 'kind': 'end'}])
    assert explicit.load_history(raw_history) == snapshot(
        [
            EndStep(result=End(data=None), ts=datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)),
        ]
    )