Coverage for faststream / opentelemetry / middleware.py: 97%
149 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import time
2from collections import defaultdict
3from collections.abc import Callable
4from copy import copy
5from typing import TYPE_CHECKING, Any, Optional, cast
7from opentelemetry import baggage, context, metrics, trace
8from opentelemetry.baggage.propagation import W3CBaggagePropagator
9from opentelemetry.context import Context
10from opentelemetry.semconv.trace import SpanAttributes
11from opentelemetry.trace import Link, Span
12from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
14from faststream._internal.middlewares import BaseMiddleware
15from faststream._internal.types import BrokerMiddleware, PublishCommandType
16from faststream.opentelemetry.baggage import Baggage
17from faststream.opentelemetry.consts import (
18 ERROR_TYPE,
19 INSTRUMENTING_LIBRARY_VERSION,
20 INSTRUMENTING_MODULE_NAME,
21 MESSAGING_DESTINATION_PUBLISH_NAME,
22 OTEL_SCHEMA,
23 WITH_BATCH,
24 MessageAction,
25)
27if TYPE_CHECKING:
28 from contextvars import Token
29 from types import TracebackType
31 from opentelemetry.metrics import Meter, MeterProvider
32 from opentelemetry.trace import Tracer, TracerProvider
33 from opentelemetry.util.types import Attributes
35 from faststream._internal.basic_types import AsyncFunc, AsyncFuncAny
36 from faststream._internal.context.repository import ContextRepo
37 from faststream.message import StreamMessage
38 from faststream.opentelemetry.provider import TelemetrySettingsProvider
# Module-level propagator instances shared by every middleware call:
# the first reads/writes W3C ``baggage`` headers, the second W3C trace-context
# (``traceparent``) headers.
_BAGGAGE_PROPAGATOR, _TRACE_PROPAGATOR = (
    W3CBaggagePropagator(),
    TraceContextTextMapPropagator(),
)
class TelemetryMiddleware(BrokerMiddleware[Any, PublishCommandType]):
    """Broker-level factory producing a ``BaseTelemetryMiddleware`` per call.

    The tracer, the meter and the metric instruments are created once in
    ``__init__`` and handed to every middleware instance built by
    ``__call__``.
    """

    __slots__ = (
        "_meter",
        "_metrics",
        "_settings_provider_factory",
        "_tracer",
    )

    def __init__(
        self,
        *,
        settings_provider_factory: Callable[
            [Any],
            Optional["TelemetrySettingsProvider[Any, PublishCommandType]"],
        ],
        tracer_provider: Optional["TracerProvider"] = None,
        meter_provider: Optional["MeterProvider"] = None,
        meter: Optional["Meter"] = None,
        include_messages_counters: bool = False,
    ) -> None:
        """Create the shared tracer, meter and metric instruments.

        Args:
            settings_provider_factory: Called with each message to obtain the
                telemetry settings provider (or ``None`` to disable telemetry
                for that message).
            tracer_provider: Provider passed to ``_get_tracer``.
            meter_provider: Provider passed to ``_get_meter``.
            meter: Explicit meter; when given it is used as-is.
            include_messages_counters: Also create message-count counters.
        """
        self._settings_provider_factory = settings_provider_factory
        self._tracer = _get_tracer(tracer_provider)
        self._meter = _get_meter(meter_provider, meter)
        self._metrics = _MetricsContainer(
            self._meter,
            include_messages_counters,
        )

    def __call__(
        self,
        msg: Any | None,
        /,
        *,
        context: "ContextRepo",
    ) -> "BaseTelemetryMiddleware[PublishCommandType]":
        """Build the per-message middleware bound to *msg* and *context*."""
        middleware_cls = BaseTelemetryMiddleware[PublishCommandType]
        return middleware_cls(
            msg,
            context=context,
            tracer=self._tracer,
            metrics_container=self._metrics,
            settings_provider_factory=self._settings_provider_factory,
        )
86class _MetricsContainer:
87 __slots__ = (
88 "include_messages_counters",
89 "process_counter",
90 "process_duration",
91 "publish_counter",
92 "publish_duration",
93 )
95 def __init__(self, meter: "Meter", include_messages_counters: bool) -> None:
96 self.include_messages_counters = include_messages_counters
98 self.publish_duration = meter.create_histogram(
99 name="messaging.publish.duration",
100 unit="s",
101 description="Measures the duration of publish operation.",
102 )
103 self.process_duration = meter.create_histogram(
104 name="messaging.process.duration",
105 unit="s",
106 description="Measures the duration of process operation.",
107 )
109 if include_messages_counters:
110 self.process_counter = meter.create_counter(
111 name="messaging.process.messages",
112 unit="message",
113 description="Measures the number of processed messages.",
114 )
115 self.publish_counter = meter.create_counter(
116 name="messaging.publish.messages",
117 unit="message",
118 description="Measures the number of published messages.",
119 )
121 def observe_publish(
122 self,
123 attrs: dict[str, Any],
124 duration: float,
125 msg_count: int,
126 ) -> None:
127 self.publish_duration.record(
128 amount=duration,
129 attributes=attrs,
130 )
131 if self.include_messages_counters:
132 counter_attrs = copy(attrs)
133 counter_attrs.pop(ERROR_TYPE, None)
134 self.publish_counter.add(
135 amount=msg_count,
136 attributes=counter_attrs,
137 )
139 def observe_consume(
140 self,
141 attrs: dict[str, Any],
142 duration: float,
143 msg_count: int,
144 ) -> None:
145 self.process_duration.record(
146 amount=duration,
147 attributes=attrs,
148 )
149 if self.include_messages_counters:
150 counter_attrs = copy(attrs)
151 counter_attrs.pop(ERROR_TYPE, None)
152 self.process_counter.add(
153 amount=msg_count,
154 attributes=counter_attrs,
155 )
class BaseTelemetryMiddleware(BaseMiddleware[PublishCommandType]):
    """Middleware that wraps consume/publish calls in OpenTelemetry spans.

    Instances are built by ``TelemetryMiddleware.__call__`` for a given
    ``msg``. ``consume_scope`` opens a CONSUMER "process" span, which is
    ended in ``after_processed``. ``publish_scope`` opens a PRODUCER
    "publish" span and injects trace/baggage headers into the outgoing
    command. Durations (and optionally message counts) are recorded through
    the shared ``_MetricsContainer``.
    """

    def __init__(
        self,
        msg: Any | None,
        /,
        *,
        tracer: "Tracer",
        settings_provider_factory: Callable[
            [Any],
            Optional["TelemetrySettingsProvider[Any, PublishCommandType]"],
        ],
        metrics_container: _MetricsContainer,
        context: "ContextRepo",
    ) -> None:
        super().__init__(msg, context=context)

        self._tracer = tracer
        self._metrics = metrics_container
        # Process span opened by ``consume_scope``; ended in ``after_processed``.
        self._current_span: Span | None = None
        # Trace context the process span was started from; injected into the
        # headers of messages published while that span is still recording.
        self._origin_context: Context | None = None
        # ContextRepo tokens for the "span"/"baggage" locals set in
        # ``consume_scope``; reset (at most once) by ``publish_scope``.
        self._scope_tokens: list[tuple[str, Token[Any]]] = []
        self.__settings_provider = settings_provider_factory(msg)

    def _reset_scope_tokens(self) -> None:
        """Restore the ContextRepo locals set in ``consume_scope``, at most once.

        The list is emptied *before* resetting. Later ``publish_scope`` calls on
        the same instance (e.g. several result publishers for one consumed
        message) therefore never reset an already-spent token again.
        NOTE(review): ``reset_local`` is presumably backed by
        ``ContextVar.reset``, whose tokens are single-use -- confirm.
        """
        tokens, self._scope_tokens = self._scope_tokens, []
        for key, token in tokens:
            self.context.reset_local(key, token)

    async def publish_scope(
        self,
        call_next: "AsyncFunc",
        msg: "PublishCommandType",
    ) -> Any:
        """Wrap a publish in a PRODUCER span and record publish metrics.

        Trace and baggage headers are written into ``msg.headers`` before
        ``call_next`` runs. If no settings provider is available for this
        middleware, ``call_next`` is awaited untouched.

        Raises:
            Whatever ``call_next`` raises; the exception type name is recorded
            as the ``ERROR_TYPE`` metric attribute first.
        """
        if (provider := self.__settings_provider) is None:
            return await call_next(msg)

        headers = msg.headers
        current_context = context.get_current()
        destination_name = provider.get_publish_destination_name(msg)

        # Forward the baggage of the message being consumed, if any.
        current_baggage: Baggage | None = self.context.get_local("baggage")
        if current_baggage:
            headers.update(current_baggage.to_headers())

        trace_attributes = provider.get_publish_attrs_from_cmd(msg)
        metrics_attributes = {
            SpanAttributes.MESSAGING_SYSTEM: provider.messaging_system,
            SpanAttributes.MESSAGING_DESTINATION_NAME: destination_name,
        }

        # NOTE: if batch with single message?
        if (msg_count := len(msg.batch_bodies)) > 1:
            trace_attributes[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = msg_count
            # Mark the outgoing message as a batch via baggage; the consume
            # side checks this flag in ``_is_batch_message``.
            current_context = _BAGGAGE_PROPAGATOR.extract(headers, current_context)
            _BAGGAGE_PROPAGATOR.inject(
                headers,
                baggage.set_baggage(WITH_BATCH, True, context=current_context),
            )

        if self._current_span and self._current_span.is_recording():
            # Publishing while this middleware's process span is still open:
            # parent the publish span on it, but propagate the context the
            # process span itself was started from.
            current_context = trace.set_span_in_context(
                self._current_span,
                current_context,
            )
            _TRACE_PROPAGATOR.inject(headers, context=self._origin_context)

        else:
            # No active process span: emit a short "create" span and
            # propagate it as the parent of the downstream consumer.
            create_span = self._tracer.start_span(
                name=_create_span_name(destination_name, MessageAction.CREATE),
                kind=trace.SpanKind.PRODUCER,
                attributes=trace_attributes,
            )
            current_context = trace.set_span_in_context(create_span)
            _TRACE_PROPAGATOR.inject(headers, context=current_context)
            create_span.end()

        start_time = time.perf_counter()

        try:
            with self._tracer.start_as_current_span(
                name=_create_span_name(destination_name, MessageAction.PUBLISH),
                kind=trace.SpanKind.PRODUCER,
                attributes=trace_attributes,
                context=current_context,
            ) as span:
                span.set_attribute(
                    SpanAttributes.MESSAGING_OPERATION,
                    MessageAction.PUBLISH,
                )
                msg.headers = headers
                result = await call_next(msg)

        except Exception as e:
            metrics_attributes[ERROR_TYPE] = type(e).__name__
            raise

        finally:
            duration = time.perf_counter() - start_time
            self._metrics.observe_publish(metrics_attributes, duration, msg_count)
            self._reset_scope_tokens()

        return result

    async def consume_scope(
        self,
        call_next: "AsyncFuncAny",
        msg: "StreamMessage[Any]",
    ) -> Any:
        """Wrap message processing in a CONSUMER span and record metrics.

        The span is opened with ``end_on_exit=False``; it is ended later in
        ``after_processed``. The span and the message baggage are exposed as
        the ContextRepo locals ``"span"`` and ``"baggage"``.

        Raises:
            Whatever ``call_next`` raises; the exception type name is recorded
            as the ``ERROR_TYPE`` metric attribute first.
        """
        if (provider := self.__settings_provider) is None:
            return await call_next(msg)

        if _is_batch_message(msg):
            # A batch has no single parent: start from an empty context and
            # attach the per-message spans as links instead.
            links = _get_msg_links(msg)
            current_context = Context()
        else:
            links = None
            current_context = _TRACE_PROPAGATOR.extract(msg.headers)

        destination_name = provider.get_consume_destination_name(msg)
        trace_attributes = provider.get_consume_attrs_from_message(msg)
        metrics_attributes = {
            SpanAttributes.MESSAGING_SYSTEM: provider.messaging_system,
            MESSAGING_DESTINATION_PUBLISH_NAME: destination_name,
        }

        if not current_context:
            # No upstream trace context: emit a "create" span to act as parent.
            create_span = self._tracer.start_span(
                name=_create_span_name(destination_name, MessageAction.CREATE),
                kind=trace.SpanKind.CONSUMER,
                attributes=trace_attributes,
                links=links,
            )
            current_context = trace.set_span_in_context(create_span)
            create_span.end()

        self._origin_context = current_context
        start_time = time.perf_counter()

        try:
            with self._tracer.start_as_current_span(
                name=_create_span_name(destination_name, MessageAction.PROCESS),
                kind=trace.SpanKind.CONSUMER,
                context=current_context,
                attributes=trace_attributes,
                end_on_exit=False,
            ) as span:
                span.set_attribute(
                    SpanAttributes.MESSAGING_OPERATION,
                    MessageAction.PROCESS,
                )
                self._current_span = span

                self._scope_tokens.append((
                    "span",
                    self.context.set_local("span", span),
                ))
                self._scope_tokens.append(
                    (
                        "baggage",
                        self.context.set_local("baggage", Baggage.from_msg(msg)),
                    ),
                )

                new_context = trace.set_span_in_context(span, current_context)
                token = context.attach(new_context)
                try:
                    result = await call_next(msg)
                finally:
                    # Keep attach/detach balanced even when the handler raises.
                    context.detach(token)

        except Exception as e:
            metrics_attributes[ERROR_TYPE] = type(e).__name__
            raise

        finally:
            duration = time.perf_counter() - start_time
            msg_count = trace_attributes.get(
                SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT,
                1,
            )
            self._metrics.observe_consume(metrics_attributes, duration, msg_count)

        return result

    async def after_processed(
        self,
        exc_type: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: Optional["TracebackType"] = None,
    ) -> bool | None:
        """End the process span opened in ``consume_scope`` if still recording.

        Always returns ``False`` (presumably: do not suppress the exception --
        confirm against ``BaseMiddleware``).
        """
        if self._current_span and self._current_span.is_recording():
            self._current_span.end()
        return False
348def _get_meter(
349 meter_provider: Optional["MeterProvider"] = None,
350 meter: Optional["Meter"] = None,
351) -> "Meter":
352 if meter is None: 352 ↛ 358line 352 didn't jump to line 358 because the condition on line 352 was always true
353 return metrics.get_meter(
354 __name__,
355 meter_provider=meter_provider,
356 schema_url=OTEL_SCHEMA,
357 )
358 return meter
def _get_tracer(tracer_provider: Optional["TracerProvider"] = None) -> "Tracer":
    """Return this library's tracer from *tracer_provider*.

    When *tracer_provider* is ``None``, ``trace.get_tracer`` falls back to the
    globally configured provider.
    """
    tracer_options = {
        "instrumenting_module_name": INSTRUMENTING_MODULE_NAME,
        "instrumenting_library_version": INSTRUMENTING_LIBRARY_VERSION,
        "tracer_provider": tracer_provider,
        "schema_url": OTEL_SCHEMA,
    }
    return trace.get_tracer(**tracer_options)
370def _create_span_name(destination: str, action: str) -> str:
371 return f"{destination} {action}"
def _is_batch_message(msg: "StreamMessage[Any]") -> bool:
    """Tell whether *msg* should be traced as a batch.

    A message counts as a batch when it carries per-item ``batch_headers`` or
    when its baggage contains the ``WITH_BATCH`` flag set on the publish side.
    """
    incoming_baggage = _BAGGAGE_PROPAGATOR.extract(msg.headers)
    batch_flag = baggage.get_baggage(WITH_BATCH, incoming_baggage)
    return bool(msg.batch_headers) or bool(batch_flag)
def _get_msg_links(msg: "StreamMessage[Any]") -> list[Link]:
    """Collect span links for the messages contained in *msg*.

    Without ``batch_headers``, at most one link is built from ``msg.headers``.
    Otherwise there is one link per ``correlation_id`` that has a trace
    context. When several items share a correlation id, the last item with a
    span wins, and its link carries the number of items seen so far with that
    id. Items without a ``correlation_id`` are ignored.
    """
    if not msg.batch_headers:
        single_span = _get_span_from_headers(msg.headers)
        if single_span is None:
            return []
        return [Link(single_span.get_span_context())]

    occurrences: dict[str, int] = defaultdict(int)
    link_by_correlation_id = {}

    for item_headers in msg.batch_headers:
        correlation_id = item_headers.get("correlation_id")
        if correlation_id is None:
            continue

        occurrences[correlation_id] += 1

        item_span = _get_span_from_headers(item_headers)
        if item_span is None:
            continue

        link_attributes = _get_link_attributes(occurrences[correlation_id])
        link_by_correlation_id[correlation_id] = Link(
            item_span.get_span_context(),
            attributes=link_attributes,
        )

    return list(link_by_correlation_id.values())
def _get_span_from_headers(headers: dict[str, Any]) -> Span | None:
    """Extract the propagated span from W3C trace-context *headers*.

    Returns ``None`` when extraction yields an empty context. Otherwise it
    returns the first value stored in the extracted context.
    """
    extracted = _TRACE_PROPAGATOR.extract(headers)
    if not extracted:
        return None

    first_value = next(iter(extracted.values()))
    return cast("Span | None", first_value)
421def _get_link_attributes(message_count: int) -> "Attributes":
422 if message_count <= 1:
423 return {}
424 return {
425 SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count,
426 }