Coverage for pydantic_evals/pydantic_evals/otel/span_tree.py: 99.74%

264 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations 

2 

3import re 

4from collections.abc import Mapping 

5from datetime import datetime, timedelta, timezone 

6from functools import partial 

7from textwrap import indent 

8from typing import TYPE_CHECKING, Any, Callable 

9 

10from typing_extensions import NotRequired, TypedDict 

11 

12__all__ = 'SpanNode', 'SpanTree', 'SpanQuery', 'as_predicate' 

13 

14if TYPE_CHECKING: # pragma: no cover 

15 # Since opentelemetry isn't a required dependency, don't actually import these at runtime 

16 from opentelemetry.sdk.trace import ReadableSpan 

17 from opentelemetry.trace import SpanContext 

18 from opentelemetry.util.types import AttributeValue 

19 

20 

class SpanNode:
    """A node in the span tree; provides references to parents/children for easy traversal and queries.

    Each node wraps one finished span (an OpenTelemetry ``ReadableSpan``, imported only for type checking) and
    exposes its identifiers, name, timing and attributes, plus search helpers over its children, descendants
    and ancestors.
    """

    def __init__(self, span: ReadableSpan):
        self._span = span
        # If a span has no context, it's going to cause problems. We may need to add improved handling of this scenario.
        assert self._span.context is not None, f'{span=} has no context'

        self.parent: SpanNode | None = None
        # Children keyed by span_id; we rely on dict insertion order to determine child order.
        self.children_by_id: dict[int, SpanNode] = {}

    @property
    def children(self) -> list[SpanNode]:
        """Return the immediate children, in the order they were added."""
        return list(self.children_by_id.values())

    @property
    def descendants(self) -> list[SpanNode]:
        """Return all descendants of this node in DFS (pre-order) order.

        Traversal uses a stack, so later children and their subtrees are visited before earlier ones.
        """
        return self.find_descendants(lambda _: True)

    @property
    def context(self) -> SpanContext:
        """Return the SpanContext of the wrapped span."""
        assert self._span.context is not None
        return self._span.context

    @property
    def parent_context(self) -> SpanContext | None:
        """Return the SpanContext of the parent of the wrapped span, or None if it has no parent."""
        return self._span.parent

    @property
    def span_id(self) -> int:
        """Return the integer span_id from the SpanContext."""
        return self.context.span_id

    @property
    def trace_id(self) -> int:
        """Return the integer trace_id from the SpanContext."""
        return self.context.trace_id

    @property
    def name(self) -> str:
        """Convenience for the span's name."""
        return self._span.name

    @property
    def start_timestamp(self) -> datetime:
        """Return the span's start time as a timezone-aware UTC datetime.

        The wrapped span must have a start time; a missing one fails an assertion (None is never returned).
        """
        assert self._span.start_time is not None
        # start_time is an integer nanosecond timestamp; fromtimestamp expects seconds.
        return datetime.fromtimestamp(self._span.start_time / 1e9, tz=timezone.utc)

    @property
    def end_timestamp(self) -> datetime:
        """Return the span's end time as a timezone-aware UTC datetime.

        The wrapped span must have an end time; a missing one fails an assertion (None is never returned).
        """
        assert self._span.end_time is not None
        # end_time is an integer nanosecond timestamp; fromtimestamp expects seconds.
        return datetime.fromtimestamp(self._span.end_time / 1e9, tz=timezone.utc)

    @property
    def duration(self) -> timedelta:
        """Return ``end_timestamp - start_timestamp``; both times must be set (see those properties)."""
        return self.end_timestamp - self.start_timestamp

    @property
    def attributes(self) -> Mapping[str, AttributeValue]:
        """Return the span's attributes, or an empty mapping if it has none."""
        # Note: It would be nice to expose the non-JSON-serialized versions of (logfire-recorded) attributes with
        # nesting etc. This just exposes the JSON-serialized version, but doing more would be difficult.
        return self._span.attributes or {}

    def add_child(self, child: SpanNode) -> None:
        """Attach ``child`` to this node (replacing any child with the same span_id) and set its parent."""
        self.children_by_id[child.span_id] = child
        child.parent = self

    # -------------------------------------------------------------------------
    # Child queries
    # -------------------------------------------------------------------------
    def find_children(self, predicate: Callable[[SpanNode], bool]) -> list[SpanNode]:
        """Return all immediate children that satisfy the given predicate."""
        return [child for child in self.children if predicate(child)]

    def first_child(self, predicate: Callable[[SpanNode], bool]) -> SpanNode | None:
        """Return the first immediate child that satisfies the given predicate, or None if none match."""
        return next((child for child in self.children if predicate(child)), None)

    def any_child(self, predicate: Callable[[SpanNode], bool]) -> bool:
        """Returns True if there is at least one child that satisfies the predicate."""
        return self.first_child(predicate) is not None

    # -------------------------------------------------------------------------
    # Descendant queries (DFS)
    # -------------------------------------------------------------------------
    def find_descendants(self, predicate: Callable[[SpanNode], bool]) -> list[SpanNode]:
        """Return all descendant nodes that satisfy the given predicate in DFS order.

        Traversal uses a stack, so later children and their subtrees are visited before earlier ones.
        """
        found: list[SpanNode] = []
        stack = list(self.children)
        while stack:
            node = stack.pop()
            if predicate(node):
                found.append(node)
            stack.extend(node.children)
        return found

    def first_descendant(self, predicate: Callable[[SpanNode], bool]) -> SpanNode | None:
        """DFS: Return the first descendant (in DFS order) that satisfies the given predicate, or `None` if none match.

        Uses the same stack-based order as `find_descendants` (later children are searched first).
        """
        stack = list(self.children)
        while stack:
            node = stack.pop()
            if predicate(node):
                return node
            stack.extend(node.children)
        return None

    def any_descendant(self, predicate: Callable[[SpanNode], bool]) -> bool:
        """Returns `True` if there is at least one descendant that satisfies the predicate."""
        return self.first_descendant(predicate) is not None

    # -------------------------------------------------------------------------
    # Ancestor queries (walking "up" the parent chain)
    # -------------------------------------------------------------------------
    def find_ancestors(self, predicate: Callable[[SpanNode], bool]) -> list[SpanNode]:
        """Return all ancestors that satisfy the given predicate, nearest first."""
        found: list[SpanNode] = []
        node = self.parent
        while node is not None:
            if predicate(node):
                found.append(node)
            node = node.parent
        return found

    def first_ancestor(self, predicate: Callable[[SpanNode], bool]) -> SpanNode | None:
        """Return the closest ancestor that satisfies the given predicate, or `None` if none match."""
        node = self.parent
        while node is not None:
            if predicate(node):
                return node
            node = node.parent
        return None

    def any_ancestor(self, predicate: Callable[[SpanNode], bool]) -> bool:
        """Returns True if any ancestor satisfies the predicate."""
        return self.first_ancestor(predicate) is not None

    # -------------------------------------------------------------------------
    # Matching convenience
    # -------------------------------------------------------------------------
    def matches(self, name: str | None = None, attributes: dict[str, Any] | None = None) -> bool:
        """A convenience method to see if this node's span matches certain conditions.

        - name: exact match for the Span name (ignored when None).
        - attributes: dict of key->value; each key's span value must equal the given value
          (a key missing from the span compares as None).
        """
        if name is not None and self.name != name:
            return False
        if attributes:
            span_attributes = self.attributes
            for attr_key, attr_val in attributes.items():
                if span_attributes.get(attr_key) != attr_val:
                    return False
        return True

    # -------------------------------------------------------------------------
    # String representation
    # -------------------------------------------------------------------------
    def repr_xml(
        self,
        include_children: bool = True,
        include_span_id: bool = False,
        include_trace_id: bool = False,
        include_start_timestamp: bool = False,
        include_duration: bool = False,
    ) -> str:
        """Return an XML-like string representation of the node.

        Optionally includes children (recursively, indented two spaces per level), span_id, trace_id,
        start_timestamp, and duration. When children exist but are excluded, ``children=...`` is shown instead.
        """
        first_line_parts = [f'<SpanNode name={self.name!r}']
        if include_span_id:
            first_line_parts.append(f'span_id={self.span_id:016x}')
        if include_trace_id:
            first_line_parts.append(f'trace_id={self.trace_id:032x}')
        if include_start_timestamp:
            first_line_parts.append(f'start_timestamp={self.start_timestamp.isoformat()!r}')
        if include_duration:
            first_line_parts.append(f"duration='{self.duration}'")

        extra_lines: list[str] = []
        if include_children and self.children:
            first_line_parts.append('>')
            for child in self.children:
                extra_lines.append(
                    indent(
                        child.repr_xml(
                            include_children=include_children,
                            include_span_id=include_span_id,
                            include_trace_id=include_trace_id,
                            include_start_timestamp=include_start_timestamp,
                            include_duration=include_duration,
                        ),
                        '  ',
                    )
                )
            extra_lines.append('</SpanNode>')
        else:
            if self.children:
                first_line_parts.append('children=...')
            first_line_parts.append('/>')
        return '\n'.join([' '.join(first_line_parts), *extra_lines])

    def __str__(self) -> str:
        if self.children:
            return f'<SpanNode name={self.name!r} span_id={self.span_id:016x}>...</SpanNode>'
        else:
            return f'<SpanNode name={self.name!r} span_id={self.span_id:016x} />'

    def __repr__(self) -> str:
        return self.repr_xml()

247 

248 

class SpanTree:
    """A container that builds a hierarchy of SpanNode objects from a list of finished spans.

    You can then search or iterate the tree to make your assertions (using DFS for traversal).
    """

    def __init__(self, spans: list[ReadableSpan] | None = None):
        # All nodes keyed by span_id; `_rebuild_tree` keeps them in start-time order.
        self.nodes_by_id: dict[int, SpanNode] = {}
        # Nodes with no parent, or whose parent is not part of this tree, in start-time order.
        self.roots: list[SpanNode] = []
        if spans:  # pragma: no cover
            self.add_spans(spans)

    def add_spans(self, spans: list[ReadableSpan]) -> None:
        """Add a list of spans to the tree, rebuilding the tree structure.

        A span whose span_id is already present replaces the existing node.
        """
        for span in spans:
            node = SpanNode(span)
            self.nodes_by_id[node.span_id] = node
        self._rebuild_tree()

    def _rebuild_tree(self) -> None:
        """Recompute ordering, parent/child links and roots from `nodes_by_id`."""
        # Ensure spans are ordered by start_timestamp so that roots and children end up in the right order.
        # (start_timestamp never returns None, so no fallback key is needed.)
        nodes = sorted(self.nodes_by_id.values(), key=lambda node: node.start_timestamp)
        self.nodes_by_id = {node.span_id: node for node in nodes}

        # Drop links from any previous build. `children_by_id` keeps the original position of keys that already
        # exist, so relinking on top of stale entries would leave newly added spans out of chronological order
        # (and would keep children of nodes that have since been replaced).
        for node in nodes:
            node.parent = None
            node.children_by_id.clear()

        # Build the parent/child relationships
        for node in nodes:
            parent_ctx = node.parent_context
            if parent_ctx is not None:
                parent_node = self.nodes_by_id.get(parent_ctx.span_id)
                if parent_node is not None:
                    parent_node.add_child(node)

        # Determine the roots
        # A node is a "root" if its parent is None or if its parent's span_id is not in the current set of spans.
        self.roots = [
            node
            for node in nodes
            if node.parent_context is None or node.parent_context.span_id not in self.nodes_by_id
        ]

    def flattened(self) -> list[SpanNode]:
        """Return a list of all nodes in the tree, in start-time order."""
        return list(self.nodes_by_id.values())

    def find_all(self, predicate: Callable[[SpanNode], bool]) -> list[SpanNode]:
        """Find all nodes in the entire tree that match the predicate, scanning from each root in DFS order.

        Traversal uses a stack, so later roots/children (and their subtrees) are visited first.
        """
        result: list[SpanNode] = []
        stack = self.roots[:]
        while stack:
            node = stack.pop()
            if predicate(node):
                result.append(node)
            stack.extend(node.children)
        return result

    def find_first(self, predicate: Callable[[SpanNode], bool]) -> SpanNode | None:
        """Find the first node that matches a predicate, scanning from each root in DFS order. Returns `None` if not found.

        Uses the same stack-based order as `find_all` (later roots/children are searched first).
        """
        stack = self.roots[:]
        while stack:
            node = stack.pop()
            if predicate(node):
                return node
            stack.extend(node.children)
        return None

    def any(self, predicate: Callable[[SpanNode], bool]) -> bool:
        """Returns True if any node in the tree matches the predicate."""
        return self.find_first(predicate) is not None

    def __str__(self):
        return f'<SpanTree num_roots={len(self.roots)} total_spans={len(self.nodes_by_id)} />'

    def repr_xml(
        self,
        include_children: bool = True,
        include_span_id: bool = False,
        include_trace_id: bool = False,
        include_start_timestamp: bool = False,
        include_duration: bool = False,
    ) -> str:
        """Return an XML-like string representation of the tree, optionally including children, span_id, trace_id, duration, and timestamps.

        Each root's representation is indented two spaces inside ``<SpanTree>``; an empty tree is ``<SpanTree />``.
        """
        if not self.roots:
            return '<SpanTree />'
        repr_parts = [
            '<SpanTree>',
            *[
                indent(
                    root.repr_xml(
                        include_children=include_children,
                        include_span_id=include_span_id,
                        include_trace_id=include_trace_id,
                        include_start_timestamp=include_start_timestamp,
                        include_duration=include_duration,
                    ),
                    '  ',
                )
                for root in self.roots
            ],
            '</SpanTree>',
        ]
        return '\n'.join(repr_parts)

    def __repr__(self):
        return self.repr_xml()

354 

355 

356class SpanQuery(TypedDict): 

357 """A serializable query for filtering SpanNodes based on various conditions. 

358 

359 All fields are optional and combined with AND logic by default. 

360 

361 Due to the presence of `__calL__`, a `SpanQuery` can be used as a predicate in `SpanTree.find_first`, etc. 

362 """ 

363 

364 # Individual span conditions 

365 ## Name conditions 

366 name_equals: NotRequired[str] 

367 name_contains: NotRequired[str] 

368 name_matches_regex: NotRequired[str] # regex pattern 

369 

370 ## Attribute conditions 

371 has_attributes: NotRequired[dict[str, Any]] 

372 has_attribute_keys: NotRequired[list[str]] 

373 

374 ## Timing conditions 

375 min_duration: NotRequired[timedelta | float] 

376 max_duration: NotRequired[timedelta | float] 

377 

378 # Logical combinations of conditions 

379 not_: NotRequired[SpanQuery] 

380 and_: NotRequired[list[SpanQuery]] 

381 or_: NotRequired[list[SpanQuery]] 

382 

383 # Descendant conditions 

384 some_child_has: NotRequired[SpanQuery] 

385 all_children_have: NotRequired[SpanQuery] 

386 no_child_has: NotRequired[SpanQuery] 

387 min_child_count: NotRequired[int] 

388 max_child_count: NotRequired[int] 

389 

390 some_descendant_has: NotRequired[SpanQuery] 

391 all_descendants_have: NotRequired[SpanQuery] 

392 no_descendant_has: NotRequired[SpanQuery] 

393 

394 

def as_predicate(query: SpanQuery) -> Callable[[SpanNode], bool]:
    """Bind ``query`` to `matches`, producing a one-argument predicate over SpanNodes.

    The returned callable takes a ``SpanNode`` and reports whether it satisfies ``query``, so it can be handed
    directly to ``SpanTree.find_first``, ``SpanNode.find_descendants`` and similar search methods.
    """
    predicate = partial(matches, query)
    return predicate

397 return partial(matches, query) 

398 

399 

def matches(query: SpanQuery, span: SpanNode) -> bool:  # noqa C901
    """Check if the span matches the query conditions.

    Every condition present in ``query`` must hold (AND semantics); absent keys impose no constraint.
    Scalar conditions (``name_equals``, ``min_child_count``, ``max_child_count`` and the durations) are applied
    whenever the key is present, so falsy values such as ``''`` or ``0`` are honoured rather than ignored.

    Args:
        query: The conditions to check; see `SpanQuery` for the available keys.
        span: The node to test.

    Returns:
        True if ``span`` satisfies every condition in ``query``, otherwise False.

    Raises:
        ValueError: If a non-empty ``or_`` is combined with other keys at the same level.
    """
    # Logical combinations
    if or_ := query.get('or_'):
        if len(query) > 1:
            raise ValueError("Cannot combine 'or_' conditions with other conditions at the same level")
        return any(matches(q, span) for q in or_)
    if (not_ := query.get('not_')) and matches(not_, span):
        return False
    if (and_ := query.get('and_')) and not all(matches(q, span) for q in and_):
        return False
    # At this point all ANDs have passed and there was no OR, so the remaining conditions decide the result.

    # Name conditions
    if (name_equals := query.get('name_equals')) is not None and span.name != name_equals:
        return False
    if (name_contains := query.get('name_contains')) and name_contains not in span.name:
        return False
    # re.match anchors the pattern at the start of the name.
    if (name_matches_regex := query.get('name_matches_regex')) and not re.match(name_matches_regex, span.name):
        return False

    # Attribute conditions
    if (has_attributes := query.get('has_attributes')) and not all(
        span.attributes.get(key) == value for key, value in has_attributes.items()
    ):
        return False
    if (has_attributes_keys := query.get('has_attribute_keys')) and not all(
        key in span.attributes for key in has_attributes_keys
    ):
        return False

    # Timing conditions (inclusive bounds; plain numbers are seconds)
    min_duration = query.get('min_duration')
    max_duration = query.get('max_duration')
    if min_duration is not None or max_duration is not None:
        duration = span.duration  # computed once for both bounds
        if min_duration is not None:
            if not isinstance(min_duration, timedelta):
                min_duration = timedelta(seconds=min_duration)
            if duration < min_duration:
                return False
        if max_duration is not None:
            if not isinstance(max_duration, timedelta):
                max_duration = timedelta(seconds=max_duration)
            if duration > max_duration:
                return False

    # Children conditions; `is not None` so that a count of 0 (e.g. "no children allowed") is enforced.
    if (min_child_count := query.get('min_child_count')) is not None and len(span.children) < min_child_count:
        return False
    if (max_child_count := query.get('max_child_count')) is not None and len(span.children) > max_child_count:
        return False
    if (some_child_has := query.get('some_child_has')) and not any(
        matches(some_child_has, child) for child in span.children
    ):
        return False
    if (all_children_have := query.get('all_children_have')) and not all(
        matches(all_children_have, child) for child in span.children
    ):
        return False
    if (no_child_has := query.get('no_child_has')) and any(matches(no_child_has, child) for child in span.children):
        return False

    # Descendant conditions; the descendant list walks the whole subtree, so build it at most once.
    some_descendant_has = query.get('some_descendant_has')
    all_descendants_have = query.get('all_descendants_have')
    no_descendant_has = query.get('no_descendant_has')
    if some_descendant_has or all_descendants_have or no_descendant_has:
        descendants = span.descendants
        if some_descendant_has and not any(matches(some_descendant_has, node) for node in descendants):
            return False
        if all_descendants_have and not all(matches(all_descendants_have, node) for node in descendants):
            return False
        if no_descendant_has and any(matches(no_descendant_has, node) for node in descendants):
            return False

    return True