Coverage for tests/graph/test_graph.py: 99.54%

214 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1# pyright: reportPrivateUsage=false 

2from __future__ import annotations as _annotations 

3 

4from dataclasses import dataclass 

5from datetime import timezone 

6from functools import cache 

7from typing import Union 

8 

9import pytest 

10from dirty_equals import IsStr 

11from inline_snapshot import snapshot 

12 

13from pydantic_graph import ( 

14 BaseNode, 

15 End, 

16 EndSnapshot, 

17 FullStatePersistence, 

18 Graph, 

19 GraphRunContext, 

20 GraphRuntimeError, 

21 GraphSetupError, 

22 NodeSnapshot, 

23 SimpleStatePersistence, 

24) 

25 

26from ..conftest import IsFloat, IsNow 

27 

# Module-level pytest mark: applies the `anyio` marker to every test in this module.
pytestmark = pytest.mark.anyio

29 

30 

# First node of the shared test chain: turns its float into text and passes it on.
@dataclass
class Float2String(BaseNode):
    input_data: float

    async def run(self, ctx: GraphRunContext) -> String2Length:
        # The next node receives `str(self.input_data)`.
        as_text = str(self.input_data)
        return String2Length(as_text)

37 

38 

# Middle node of the shared test chain: measures the string it was given.
@dataclass
class String2Length(BaseNode):
    input_data: str

    async def run(self, ctx: GraphRunContext) -> Double:
        # The next node receives the character count of the stored string.
        char_count = len(self.input_data)
        return Double(char_count)

45 

46 

# Last node of the shared test chain: either ends the run or loops back once.
@dataclass
class Double(BaseNode[None, None, int]):
    input_data: int

    async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]:  # noqa: UP007
        # Any value other than 7 finishes the run with twice the input.
        if self.input_data != 7:
            return End(self.input_data * 2)
        # Exactly 7 sends a 21-character string back to `String2Length`.
        return String2Length('x' * 21)

56 

57 

async def test_graph():
    """Run the three-node chain once and check its output, inferred types and name."""
    my_graph = Graph(nodes=(Float2String, String2Length, Double))

    # Before the first run the graph has no name.
    assert my_graph.name is None
    assert my_graph.inferred_types == (type(None), int)

    start_node = Float2String(3.14)
    result = await my_graph.run(start_node)

    # After running, the name matches the local variable holding the graph.
    assert my_graph.name == 'my_graph'
    # len('3.14') * 2 == 8
    assert result.output == 8

66 

67 

async def test_graph_history(mock_snapshot_id: object):
    """Check the persisted history of two runs: one straight pass and one that loops once."""
    my_graph = Graph[None, None, int](nodes=(Float2String, String2Length, Double))
    assert my_graph.name is None
    assert my_graph.inferred_types == (type(None), int)

    # First run: len('3.14') * 2 == 8, so `Double` ends the run straight away.
    straight_run = FullStatePersistence()
    result = await my_graph.run(Float2String(3.14), persistence=straight_run)
    assert result.output == 8
    assert my_graph.name == 'my_graph'
    assert straight_run.history == snapshot(
        [
            NodeSnapshot(
                id='Float2String:1',
                node=Float2String(input_data=3.14),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            NodeSnapshot(
                id='String2Length:2',
                node=String2Length(input_data='3.14'),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            NodeSnapshot(
                id='Double:3',
                node=Double(input_data=4),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            EndSnapshot(id='end:4', state=None, result=End(data=8), ts=IsNow(tz=timezone.utc)),
        ]
    )

    # Second run: len('3.14159') == 7 triggers the loop back, then 21 * 2 == 42.
    # The snapshot ids carry on counting from the first run.
    looped_run = FullStatePersistence()
    looped_result = await my_graph.run(Float2String(3.14159), persistence=looped_run)
    assert looped_result.output == 42
    assert looped_run.history == snapshot(
        [
            NodeSnapshot(
                id='Float2String:5',
                node=Float2String(input_data=3.14159),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            NodeSnapshot(
                id='String2Length:6',
                node=String2Length(input_data='3.14159'),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            NodeSnapshot(
                id='Double:7',
                node=Double(input_data=7),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            NodeSnapshot(
                id='String2Length:8',
                node=String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            NodeSnapshot(
                id='Double:9',
                node=Double(input_data=21),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            EndSnapshot(id='end:10', state=None, result=End(data=42), ts=IsNow(tz=timezone.utc)),
        ]
    )

    # The `node` attribute of each history entry, in execution order.
    visited_nodes = [entry.node for entry in looped_run.history]
    assert visited_nodes == snapshot(
        [
            Float2String(input_data=3.14159),
            String2Length(input_data='3.14159'),
            Double(input_data=7),
            String2Length(input_data='xxxxxxxxxxxxxxxxxxxxx'),
            Double(input_data=21),
            End(data=42),
        ]
    )

165 

166 

def test_one_bad_node():
    """Building a graph whose only node points at a node left out of it raises `GraphSetupError`."""

    # Refers to the local `String2Length` below, which is deliberately not given to the graph.
    class Float2String(BaseNode):
        async def run(self, ctx: GraphRunContext) -> String2Length:
            raise NotImplementedError()

    class String2Length(BaseNode[None, None, None]):  # pyright: ignore[reportUnusedClass]
        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    with pytest.raises(GraphSetupError) as raised:
        Graph(nodes=(Float2String,))

    message = raised.value.message
    assert message == snapshot(
        '`String2Length` is referenced by `Float2String` but not included in the graph.'
    )

182 

183 

def test_two_bad_nodes():
    """Two missing nodes referenced by one included node are listed together in the error."""

    # `Foo` can go to either `Bar` or `Spam`; only `Foo` is handed to the graph.
    class Foo(BaseNode):
        input_data: float

        async def run(self, ctx: GraphRunContext) -> Union[Bar, Spam]:  # noqa: UP007
            raise NotImplementedError()

    class Bar(BaseNode[None, None, None]):
        input_data: str

        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    class Spam(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    with pytest.raises(GraphSetupError) as raised:
        Graph(nodes=(Foo,))

    message = raised.value.message
    assert message == snapshot("""\
Nodes are referenced in the graph but not included in the graph:
 `Bar` is referenced by `Foo`
 `Spam` is referenced by `Foo`\
""")

209 

210 

def test_three_bad_nodes_separate():
    """One missing node referenced by three included nodes is reported in a single sentence."""

    # `Foo`, `Bar` and `Spam` all lead to `Eggs`, which is left out of the graph.
    class Foo(BaseNode):
        input_data: float

        async def run(self, ctx: GraphRunContext) -> Eggs:
            raise NotImplementedError()

    class Bar(BaseNode[None, None, None]):
        input_data: str

        async def run(self, ctx: GraphRunContext) -> Eggs:
            raise NotImplementedError()

    class Spam(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> Eggs:
            raise NotImplementedError()

    class Eggs(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    with pytest.raises(GraphSetupError) as raised:
        Graph(nodes=(Foo, Bar, Spam))

    message = raised.value.message
    assert message == snapshot(
        '`Eggs` is referenced by `Foo`, `Bar`, and `Spam` but not included in the graph.'
    )

238 

239 

def test_duplicate_id():
    """Two node classes reporting the same node id make graph construction fail."""

    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            raise NotImplementedError()

    class Bar(BaseNode[None, None, None]):
        # Overridden to return the id 'Foo', the name of the other node class in this graph.
        @classmethod
        @cache
        def get_node_id(cls) -> str:
            return 'Foo'

        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    with pytest.raises(GraphSetupError) as raised:
        Graph(nodes=(Foo, Bar))

    # Only the start of the message is pinned; the rest is matched by the regex tail.
    expected_message = snapshot(IsStr(regex='Node ID `Foo` is not unique — found on <class.+'))
    assert raised.value.message == expected_message

258 

259 

async def test_run_node_not_in_graph():
    """Returning an instance of a node class the graph was not built with raises `GraphRuntimeError`."""

    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            # Deliberately contradicts the annotation: `Spam` is not registered in the graph.
            return Spam()  # type: ignore

    @dataclass
    class Spam(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            raise NotImplementedError()

    g = Graph(nodes=(Foo, Bar))
    with pytest.raises(GraphRuntimeError) as raised:
        await g.run(Foo())

    message = raised.value.message
    assert message == snapshot('Node `test_run_node_not_in_graph.<locals>.Spam()` is not in the graph.')

281 

282 

async def test_run_return_other(mock_snapshot_id: object):
    """A node returning something that is neither a node nor `End` raises `GraphRuntimeError`."""

    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            # Deliberately contradicts the annotation by returning a plain int.
            return 42  # type: ignore

    g = Graph(nodes=(Foo, Bar))
    none_type = type(None)
    assert g.inferred_types == (none_type, none_type)

    with pytest.raises(GraphRuntimeError) as raised:
        await g.run(Foo())

    message = raised.value.message
    assert message == snapshot('Invalid node return type: `int`. Expected `BaseNode` or `End`.')

300 

301 

async def test_iter():
    """Iterate a run with `async for` and check the yielded nodes and the final result."""
    my_graph = Graph(nodes=(Float2String, String2Length, Double))
    assert my_graph.name is None
    assert my_graph.inferred_types == (type(None), int)

    async with my_graph.iter(Float2String(3.14)) as graph_iter:
        assert repr(graph_iter) == snapshot('<GraphRun graph=my_graph>')
        # Collect the repr of everything the run yields, in order.
        node_reprs: list[str] = [repr(node) async for node in graph_iter]
        # len('3.14') * 2 == 8
        assert graph_iter.result
        assert graph_iter.result.output == 8

    assert node_reprs == snapshot(["String2Length(input_data='3.14')", 'Double(input_data=4)', 'End(data=8)'])

316 

317 

async def test_iter_next(mock_snapshot_id: object):
    """Step a two-node cycle by hand with `run.next()` and check the history after each step."""

    # `Foo` and `Bar` return each other, so the run never reaches an `End`.
    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Foo:
            return Foo()

    g = Graph(nodes=(Foo, Bar))
    assert g.name is None
    persistence = FullStatePersistence()

    async with g.iter(Foo(), persistence=persistence) as run:
        assert g.name == 'g'

        # Step 1: `Foo` has run; `Bar` is recorded but has no timing or status set yet.
        after_foo = await run.next()
        assert after_foo == Bar()
        assert persistence.history == snapshot(
            [
                NodeSnapshot(
                    id='Foo:1',
                    node=Foo(),
                    state=None,
                    status='success',
                    start_ts=IsNow(tz=timezone.utc),
                    duration=IsFloat(),
                ),
                NodeSnapshot(id='Bar:2', node=Bar(), state=None),
            ]
        )

        # Step 2: `Bar` has run; a fresh `Foo` is recorded as the next entry.
        assert isinstance(after_foo, Bar)
        after_bar = await run.next()
        assert after_bar == Foo()

        assert persistence.history == snapshot(
            [
                NodeSnapshot(
                    id='Foo:1',
                    node=Foo(),
                    state=None,
                    status='success',
                    start_ts=IsNow(tz=timezone.utc),
                    duration=IsFloat(),
                ),
                NodeSnapshot(
                    id='Bar:2',
                    node=Bar(),
                    state=None,
                    status='success',
                    start_ts=IsNow(tz=timezone.utc),
                    duration=IsFloat(),
                ),
                NodeSnapshot(id='Foo:3', node=Foo(), state=None),
            ]
        )

375 

376 

async def test_iter_next_error(mock_snapshot_id: object):
    """Calling `run.next()` again after the run has reached `End` raises `TypeError`."""

    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode[None, None, None]):
        async def run(self, ctx: GraphRunContext) -> End[None]:
            return End(None)

    g = Graph(nodes=(Foo, Bar))
    persistence = SimpleStatePersistence()

    async with g.iter(Foo(), persistence=persistence) as run:
        first_step = await run.next()
        assert first_step == snapshot(Bar())

        assert isinstance(first_step, BaseNode)
        second_step = await run.next()
        assert second_step == snapshot(End(data=None))

        # The run has already produced `End`, so a further step is rejected.
        expected_error = r'`next` must be called with a `BaseNode` instance, got End\(data=None\).'
        with pytest.raises(TypeError, match=expected_error):
            await run.next()

400 

401 

async def test_next(mock_snapshot_id: object):
    """Check the deprecated `Graph.next()`: it warns, runs one node and records the history.

    The warning pattern is a raw string with the regex metacharacters escaped.
    `pytest.warns(match=...)` treats the pattern as a regular expression, so an
    unescaped `(...)` would be a capture group matching any three characters
    rather than the literal text.
    """

    # `Foo` and `Bar` return each other; only a single step is executed here.
    @dataclass
    class Foo(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Bar:
            return Bar()

    @dataclass
    class Bar(BaseNode):
        async def run(self, ctx: GraphRunContext) -> Foo:
            return Foo()

    g = Graph(nodes=(Foo, Bar))
    assert g.name is None
    sp = FullStatePersistence()

    # Escaped so that `.` and `(...)` are matched literally.
    deprecation_pattern = r'`next` is deprecated, use `async with graph\.iter\(\.\.\.\)'
    with pytest.warns(DeprecationWarning, match=deprecation_pattern):
        n = await g.next(Foo(), persistence=sp)  # pyright: ignore[reportDeprecated]

    assert n == Bar()
    # The name matches the local variable holding the graph.
    assert g.name == 'g'
    # `Foo` has run; `Bar` is recorded but has no timing or status set yet.
    assert sp.history == snapshot(
        [
            NodeSnapshot(
                state=None,
                node=Foo(),
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
                status='success',
                id='Foo:1',
            ),
            NodeSnapshot(state=None, node=Bar(), id='Bar:2'),
        ]
    )

433 

434 

async def test_deps(mock_snapshot_id: object):
    """Dependencies passed to `Graph.run` are visible on `ctx.deps` inside every node."""

    @dataclass
    class Deps:
        a: int
        b: int

    @dataclass
    class Foo(BaseNode[None, Deps]):
        async def run(self, ctx: GraphRunContext[None, Deps]) -> Bar:
            assert isinstance(ctx.deps, Deps)
            return Bar()

    @dataclass
    class Bar(BaseNode[None, Deps, int]):
        async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]:
            assert isinstance(ctx.deps, Deps)
            return End(123)

    g = Graph(nodes=(Foo, Bar))
    persistence = FullStatePersistence()
    dependencies = Deps(1, 2)
    result = await g.run(Foo(), deps=dependencies, persistence=persistence)

    assert result.output == 123
    assert persistence.history == snapshot(
        [
            NodeSnapshot(
                id='Foo:1',
                node=Foo(),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            NodeSnapshot(
                id='Bar:2',
                node=Bar(),
                state=None,
                status='success',
                start_ts=IsNow(tz=timezone.utc),
                duration=IsFloat(),
            ),
            EndSnapshot(id='end:3', state=None, result=End(data=123), ts=IsNow(tz=timezone.utc)),
        ]
    )