Coverage for pydantic_evals/pydantic_evals/otel/span_tree.py: 99.74%
264 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations
3import re
4from collections.abc import Mapping
5from datetime import datetime, timedelta, timezone
6from functools import partial
7from textwrap import indent
8from typing import TYPE_CHECKING, Any, Callable
10from typing_extensions import NotRequired, TypedDict
12__all__ = 'SpanNode', 'SpanTree', 'SpanQuery', 'as_predicate'
14if TYPE_CHECKING: # pragma: no cover
15 # Since opentelemetry isn't a required dependency, don't actually import these at runtime
16 from opentelemetry.sdk.trace import ReadableSpan
17 from opentelemetry.trace import SpanContext
18 from opentelemetry.util.types import AttributeValue
class SpanNode:
    """A node in the span tree; provides references to parents/children for easy traversal and queries.

    Children are kept in insertion order (see `children_by_id`). Every descendant traversal is a
    pre-order depth-first walk that visits siblings in that same order, so `first_descendant` and
    `first_child` agree whenever the first match is an immediate child.
    """

    def __init__(self, span: ReadableSpan):
        self._span = span
        # If a span has no context, it's going to cause problems. We may need to add improved handling of this scenario.
        assert self._span.context is not None, f'{span=} has no context'

        self.parent: SpanNode | None = None
        self.children_by_id: dict[int, SpanNode] = {}  # note: we rely on insertion order to determine child order

    @property
    def children(self) -> list[SpanNode]:
        """Return the immediate children as a new list, in insertion order."""
        return list(self.children_by_id.values())

    @property
    def descendants(self) -> list[SpanNode]:
        """Return all descendants of this node in pre-order DFS order."""
        return self.find_descendants(lambda _node: True)

    @property
    def context(self) -> SpanContext:
        """Return the SpanContext of the wrapped span."""
        assert self._span.context is not None
        return self._span.context

    @property
    def parent_context(self) -> SpanContext | None:
        """Return the SpanContext of the parent of the wrapped span, or `None` if it has none."""
        return self._span.parent

    @property
    def span_id(self) -> int:
        """Return the integer span_id from the SpanContext."""
        return self.context.span_id

    @property
    def trace_id(self) -> int:
        """Return the integer trace_id from the SpanContext."""
        return self.context.trace_id

    @property
    def name(self) -> str:
        """Convenience for the span's name."""
        return self._span.name

    @property
    def start_timestamp(self) -> datetime:
        """Return the span's start time as a UTC datetime.

        The wrapped span's `start_time` is divided by 1e9 before conversion (i.e. it is treated as
        nanoseconds since the epoch). An unset start time fails the assertion rather than returning `None`.
        """
        assert self._span.start_time is not None
        return datetime.fromtimestamp(self._span.start_time / 1e9, tz=timezone.utc)

    @property
    def end_timestamp(self) -> datetime:
        """Return the span's end time as a UTC datetime.

        The wrapped span's `end_time` is divided by 1e9 before conversion (i.e. it is treated as
        nanoseconds since the epoch). An unset end time fails the assertion rather than returning `None`.
        """
        assert self._span.end_time is not None
        return datetime.fromtimestamp(self._span.end_time / 1e9, tz=timezone.utc)

    @property
    def duration(self) -> timedelta:
        """Return the span's duration (`end_timestamp - start_timestamp`) as a timedelta."""
        return self.end_timestamp - self.start_timestamp

    @property
    def attributes(self) -> Mapping[str, AttributeValue]:
        """Return the span's attributes, or an empty mapping if the span has none."""
        # Note: It would be nice to expose the non-JSON-serialized versions of (logfire-recorded) attributes with
        # nesting etc. This just exposes the JSON-serialized version, but doing more would be difficult.
        return self._span.attributes or {}

    def add_child(self, child: SpanNode) -> None:
        """Attach a child node to this node's children and set this node as the child's parent."""
        self.children_by_id[child.span_id] = child
        child.parent = self

    # -------------------------------------------------------------------------
    # Child queries
    # -------------------------------------------------------------------------
    def find_children(self, predicate: Callable[[SpanNode], bool]) -> list[SpanNode]:
        """Return all immediate children that satisfy the given predicate."""
        return [child for child in self.children_by_id.values() if predicate(child)]

    def first_child(self, predicate: Callable[[SpanNode], bool]) -> SpanNode | None:
        """Return the first immediate child that satisfies the given predicate, or None if none match."""
        return next((child for child in self.children_by_id.values() if predicate(child)), None)

    def any_child(self, predicate: Callable[[SpanNode], bool]) -> bool:
        """Returns True if there is at least one child that satisfies the predicate."""
        return self.first_child(predicate) is not None

    # -------------------------------------------------------------------------
    # Descendant queries (DFS)
    # -------------------------------------------------------------------------
    def find_descendants(self, predicate: Callable[[SpanNode], bool]) -> list[SpanNode]:
        """Return all descendant nodes that satisfy the given predicate in pre-order DFS order."""
        found: list[SpanNode] = []
        # The stack is LIFO, so children are pushed reversed to be popped in insertion order.
        stack = self.children[::-1]
        while stack:
            node = stack.pop()
            if predicate(node):
                found.append(node)
            stack.extend(reversed(node.children))
        return found

    def first_descendant(self, predicate: Callable[[SpanNode], bool]) -> SpanNode | None:
        """DFS: Return the first descendant (in pre-order DFS order) that satisfies the given predicate, or `None` if none match."""
        # The stack is LIFO, so children are pushed reversed to be popped in insertion order.
        stack = self.children[::-1]
        while stack:
            node = stack.pop()
            if predicate(node):
                return node
            stack.extend(reversed(node.children))
        return None

    def any_descendant(self, predicate: Callable[[SpanNode], bool]) -> bool:
        """Returns `True` if there is at least one descendant that satisfies the predicate."""
        return self.first_descendant(predicate) is not None

    # -------------------------------------------------------------------------
    # Ancestor queries (walking "up" the parent chain)
    # -------------------------------------------------------------------------
    def find_ancestors(self, predicate: Callable[[SpanNode], bool]) -> list[SpanNode]:
        """Return all ancestors that satisfy the given predicate, closest ancestor first."""
        found: list[SpanNode] = []
        node = self.parent
        while node:
            if predicate(node):
                found.append(node)
            node = node.parent
        return found

    def first_ancestor(self, predicate: Callable[[SpanNode], bool]) -> SpanNode | None:
        """Return the closest ancestor that satisfies the given predicate, or `None` if none match."""
        node = self.parent
        while node:
            if predicate(node):
                return node
            node = node.parent
        return None

    def any_ancestor(self, predicate: Callable[[SpanNode], bool]) -> bool:
        """Returns True if any ancestor satisfies the predicate."""
        return self.first_ancestor(predicate) is not None

    # -------------------------------------------------------------------------
    # Matching convenience
    # -------------------------------------------------------------------------
    def matches(self, name: str | None = None, attributes: dict[str, Any] | None = None) -> bool:
        """A convenience method to see if this node's span matches certain conditions.

        - name: exact match for the Span name
        - attributes: dict of key->value; every listed key must be present on the span with an equal
          value (attributes of the span that are not listed are ignored).
        """
        if name is not None and self.name != name:
            return False
        if attributes:
            span_attributes = self.attributes
            if any(span_attributes.get(attr_key) != attr_val for attr_key, attr_val in attributes.items()):
                return False
        return True

    # -------------------------------------------------------------------------
    # String representation
    # -------------------------------------------------------------------------
    def repr_xml(
        self,
        include_children: bool = True,
        include_span_id: bool = False,
        include_trace_id: bool = False,
        include_start_timestamp: bool = False,
        include_duration: bool = False,
    ) -> str:
        """Return an XML-like string representation of the node.

        Optionally includes children, span_id, trace_id, start_timestamp, and duration. When children
        exist but `include_children` is false, a `children=...` marker is emitted instead.
        """
        first_line_parts = [f'<SpanNode name={self.name!r}']
        if include_span_id:
            first_line_parts.append(f'span_id={self.span_id:016x}')
        if include_trace_id:
            first_line_parts.append(f'trace_id={self.trace_id:032x}')
        if include_start_timestamp:
            first_line_parts.append(f'start_timestamp={self.start_timestamp.isoformat()!r}')
        if include_duration:
            first_line_parts.append(f"duration='{self.duration}'")

        children = self.children
        extra_lines: list[str] = []
        if include_children and children:
            first_line_parts.append('>')
            for child in children:
                child_xml = child.repr_xml(
                    include_children=include_children,
                    include_span_id=include_span_id,
                    include_trace_id=include_trace_id,
                    include_start_timestamp=include_start_timestamp,
                    include_duration=include_duration,
                )
                # NOTE(review): indent width copied from a whitespace-collapsed listing — confirm
                # against the repository before merging.
                extra_lines.append(indent(child_xml, ' '))
            extra_lines.append('</SpanNode>')
        else:
            if children:
                first_line_parts.append('children=...')
            first_line_parts.append('/>')
        return '\n'.join([' '.join(first_line_parts), *extra_lines])

    def __str__(self) -> str:
        if self.children:
            return f'<SpanNode name={self.name!r} span_id={self.span_id:016x}>...</SpanNode>'
        else:
            return f'<SpanNode name={self.name!r} span_id={self.span_id:016x} />'

    def __repr__(self) -> str:
        return self.repr_xml()
class SpanTree:
    """A container that builds a hierarchy of SpanNode objects from a list of finished spans.

    You can then search or iterate the tree to make your assertions (using DFS for traversal).
    Roots and children are ordered by start timestamp, and traversals visit them in that order.
    """

    def __init__(self, spans: list[ReadableSpan] | None = None):
        self.nodes_by_id: dict[int, SpanNode] = {}
        self.roots: list[SpanNode] = []
        if spans:  # pragma: no cover
            self.add_spans(spans)

    def add_spans(self, spans: list[ReadableSpan]) -> None:
        """Add a list of spans to the tree, rebuilding the tree structure.

        A span whose span_id is already present replaces the existing node for that id.
        """
        for span in spans:
            node = SpanNode(span)
            self.nodes_by_id[node.span_id] = node
        self._rebuild_tree()

    def _rebuild_tree(self) -> None:
        """Recompute node ordering, parent/child links and roots from `nodes_by_id`."""
        # Ensure spans are ordered by start_timestamp so that roots and children end up in the right order
        nodes = sorted(self.nodes_by_id.values(), key=lambda node: node.start_timestamp)
        self.nodes_by_id = {node.span_id: node for node in nodes}

        # Drop links from any previous build. `children_by_id` is an insertion-ordered dict, so
        # re-adding an existing child would keep its old position instead of the sorted one.
        for node in nodes:
            node.parent = None
            node.children_by_id = {}

        # Build the parent/child relationships and determine the roots.
        # A node is a "root" if its parent is None or if its parent's span_id is not in the current set of spans.
        self.roots = []
        for node in nodes:
            parent_ctx = node.parent_context
            parent_node = None if parent_ctx is None else self.nodes_by_id.get(parent_ctx.span_id)
            if parent_node is None:
                self.roots.append(node)
            else:
                parent_node.add_child(node)

    def flattened(self) -> list[SpanNode]:
        """Return a list of all nodes in the tree, ordered by start timestamp."""
        return list(self.nodes_by_id.values())

    def find_all(self, predicate: Callable[[SpanNode], bool]) -> list[SpanNode]:
        """Find all nodes in the entire tree that match the predicate, scanning from each root in DFS order."""
        result: list[SpanNode] = []
        # The stack is LIFO, so nodes are pushed reversed to be popped in their stored order.
        stack = self.roots[::-1]
        while stack:
            node = stack.pop()
            if predicate(node):
                result.append(node)
            stack.extend(reversed(node.children))
        return result

    def find_first(self, predicate: Callable[[SpanNode], bool]) -> SpanNode | None:
        """Find the first node that matches a predicate, scanning from each root in DFS order. Returns `None` if not found."""
        # The stack is LIFO, so nodes are pushed reversed to be popped in their stored order.
        stack = self.roots[::-1]
        while stack:
            node = stack.pop()
            if predicate(node):
                return node
            stack.extend(reversed(node.children))
        return None

    def any(self, predicate: Callable[[SpanNode], bool]) -> bool:
        """Returns True if any node in the tree matches the predicate."""
        return self.find_first(predicate) is not None

    def __str__(self):
        return f'<SpanTree num_roots={len(self.roots)} total_spans={len(self.nodes_by_id)} />'

    def repr_xml(
        self,
        include_children: bool = True,
        include_span_id: bool = False,
        include_trace_id: bool = False,
        include_start_timestamp: bool = False,
        include_duration: bool = False,
    ) -> str:
        """Return an XML-like string representation of the tree, optionally including children, span_id, trace_id, duration, and timestamps."""
        if not self.roots:
            return '<SpanTree />'
        repr_parts = ['<SpanTree>']
        for root in self.roots:
            root_xml = root.repr_xml(
                include_children=include_children,
                include_span_id=include_span_id,
                include_trace_id=include_trace_id,
                include_start_timestamp=include_start_timestamp,
                include_duration=include_duration,
            )
            # NOTE(review): indent width copied from a whitespace-collapsed listing — confirm
            # against the repository before merging.
            repr_parts.append(indent(root_xml, ' '))
        repr_parts.append('</SpanTree>')
        return '\n'.join(repr_parts)

    def __repr__(self):
        return self.repr_xml()
356class SpanQuery(TypedDict):
357 """A serializable query for filtering SpanNodes based on various conditions.
359 All fields are optional and combined with AND logic by default.
361 Due to the presence of `__calL__`, a `SpanQuery` can be used as a predicate in `SpanTree.find_first`, etc.
362 """
364 # Individual span conditions
365 ## Name conditions
366 name_equals: NotRequired[str]
367 name_contains: NotRequired[str]
368 name_matches_regex: NotRequired[str] # regex pattern
370 ## Attribute conditions
371 has_attributes: NotRequired[dict[str, Any]]
372 has_attribute_keys: NotRequired[list[str]]
374 ## Timing conditions
375 min_duration: NotRequired[timedelta | float]
376 max_duration: NotRequired[timedelta | float]
378 # Logical combinations of conditions
379 not_: NotRequired[SpanQuery]
380 and_: NotRequired[list[SpanQuery]]
381 or_: NotRequired[list[SpanQuery]]
383 # Descendant conditions
384 some_child_has: NotRequired[SpanQuery]
385 all_children_have: NotRequired[SpanQuery]
386 no_child_has: NotRequired[SpanQuery]
387 min_child_count: NotRequired[int]
388 max_child_count: NotRequired[int]
390 some_descendant_has: NotRequired[SpanQuery]
391 all_descendants_have: NotRequired[SpanQuery]
392 no_descendant_has: NotRequired[SpanQuery]
def as_predicate(query: SpanQuery) -> Callable[[SpanNode], bool]:
    """Bind `query` to `matches`, producing a one-argument predicate over `SpanNode`.

    The result can be passed anywhere a `Callable[[SpanNode], bool]` is expected, e.g.
    `SpanTree.find_first`.
    """
    predicate = partial(matches, query)
    return predicate
def matches(query: SpanQuery, span: SpanNode) -> bool:  # noqa C901
    """Check if the span matches the query conditions.

    All conditions present in `query` must hold (AND), except that `or_` must be the only key at
    its level and matches when any of its sub-queries matches. An empty `or_` list is ignored.

    Conditions are recognised by the key being set (not by the value being truthy), so
    `max_child_count: 0`, `name_equals: ''` and empty sub-queries (which match every span) are
    honoured rather than silently skipped.

    Raises:
        ValueError: if `or_` is combined with other conditions at the same level.
    """
    # Logical combinations
    if or_ := query.get('or_'):
        if len(query) > 1:
            raise ValueError("Cannot combine 'or_' conditions with other conditions at the same level")
        return any(matches(q, span) for q in or_)
    if (not_ := query.get('not_')) is not None and matches(not_, span):
        return False
    if (and_ := query.get('and_')) and not all(matches(q, span) for q in and_):
        return False
    # At this point, all existing ANDs and no existing ORs have passed, so it comes down to this condition

    # Name conditions
    if (name_equals := query.get('name_equals')) is not None and span.name != name_equals:
        return False
    if (name_contains := query.get('name_contains')) and name_contains not in span.name:
        return False
    if (name_matches_regex := query.get('name_matches_regex')) and not re.match(name_matches_regex, span.name):
        return False

    # Attribute conditions
    if (has_attributes := query.get('has_attributes')) and not all(
        span.attributes.get(key) == value for key, value in has_attributes.items()
    ):
        return False
    if (has_attributes_keys := query.get('has_attribute_keys')) and not all(
        key in span.attributes for key in has_attributes_keys
    ):
        return False

    # Timing conditions (bare numbers are interpreted as seconds)
    if (min_duration := query.get('min_duration')) is not None:
        if not isinstance(min_duration, timedelta):
            min_duration = timedelta(seconds=min_duration)
        if span.duration < min_duration:
            return False
    if (max_duration := query.get('max_duration')) is not None:
        if not isinstance(max_duration, timedelta):
            max_duration = timedelta(seconds=max_duration)
        if span.duration > max_duration:
            return False

    # Children conditions
    if (min_child_count := query.get('min_child_count')) is not None and len(span.children) < min_child_count:
        return False
    # `is not None` matters here: a limit of 0 ("no children allowed") is falsy but meaningful.
    if (max_child_count := query.get('max_child_count')) is not None and len(span.children) > max_child_count:
        return False
    if (some_child_has := query.get('some_child_has')) is not None and not any(
        matches(some_child_has, child) for child in span.children
    ):
        return False
    if (all_children_have := query.get('all_children_have')) and not all(
        matches(all_children_have, child) for child in span.children
    ):
        return False
    if (no_child_has := query.get('no_child_has')) is not None and any(
        matches(no_child_has, child) for child in span.children
    ):
        return False

    # Descendant conditions
    if (some_descendant_has := query.get('some_descendant_has')) is not None and not any(
        matches(some_descendant_has, descendant) for descendant in span.descendants
    ):
        return False
    if (all_descendants_have := query.get('all_descendants_have')) and not all(
        matches(all_descendants_have, descendant) for descendant in span.descendants
    ):
        return False
    if (no_descendant_has := query.get('no_descendant_has')) is not None and any(
        matches(no_descendant_has, descendant) for descendant in span.descendants
    ):
        return False

    return True