Coverage for faststream / opentelemetry / middleware.py: 97%

149 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import time 

2from collections import defaultdict 

3from collections.abc import Callable 

4from copy import copy 

5from typing import TYPE_CHECKING, Any, Optional, cast 

6 

7from opentelemetry import baggage, context, metrics, trace 

8from opentelemetry.baggage.propagation import W3CBaggagePropagator 

9from opentelemetry.context import Context 

10from opentelemetry.semconv.trace import SpanAttributes 

11from opentelemetry.trace import Link, Span 

12from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator 

13 

14from faststream._internal.middlewares import BaseMiddleware 

15from faststream._internal.types import BrokerMiddleware, PublishCommandType 

16from faststream.opentelemetry.baggage import Baggage 

17from faststream.opentelemetry.consts import ( 

18 ERROR_TYPE, 

19 INSTRUMENTING_LIBRARY_VERSION, 

20 INSTRUMENTING_MODULE_NAME, 

21 MESSAGING_DESTINATION_PUBLISH_NAME, 

22 OTEL_SCHEMA, 

23 WITH_BATCH, 

24 MessageAction, 

25) 

26 

27if TYPE_CHECKING: 

28 from contextvars import Token 

29 from types import TracebackType 

30 

31 from opentelemetry.metrics import Meter, MeterProvider 

32 from opentelemetry.trace import Tracer, TracerProvider 

33 from opentelemetry.util.types import Attributes 

34 

35 from faststream._internal.basic_types import AsyncFunc, AsyncFuncAny 

36 from faststream._internal.context.repository import ContextRepo 

37 from faststream.message import StreamMessage 

38 from faststream.opentelemetry.provider import TelemetrySettingsProvider 

39 

40 

# Module-level propagator singletons reused by every middleware instance and
# helper below: one reads/writes baggage headers, the other trace-context
# headers.
_BAGGAGE_PROPAGATOR: W3CBaggagePropagator = W3CBaggagePropagator()
_TRACE_PROPAGATOR: TraceContextTextMapPropagator = TraceContextTextMapPropagator()

43 

44 

class TelemetryMiddleware(BrokerMiddleware[Any, PublishCommandType]):
    """Broker-level factory for per-message OpenTelemetry middlewares.

    The tracer, meter and metric instruments are resolved once, at
    construction, and handed to a fresh ``BaseTelemetryMiddleware`` every
    time the factory is called for a message.
    """

    __slots__ = (
        "_meter",
        "_metrics",
        "_settings_provider_factory",
        "_tracer",
    )

    def __init__(
        self,
        *,
        settings_provider_factory: Callable[
            [Any],
            Optional["TelemetrySettingsProvider[Any, PublishCommandType]"],
        ],
        tracer_provider: Optional["TracerProvider"] = None,
        meter_provider: Optional["MeterProvider"] = None,
        meter: Optional["Meter"] = None,
        include_messages_counters: bool = False,
    ) -> None:
        """Resolve the shared telemetry instruments.

        Args:
            settings_provider_factory: Called with each message; returns the
                provider describing it, or ``None`` to disable telemetry for
                that message.
            tracer_provider: Forwarded to ``_get_tracer``.
            meter_provider: Forwarded to ``_get_meter``; only used when
                ``meter`` is not supplied.
            meter: Ready-made meter that takes precedence over
                ``meter_provider``.
            include_messages_counters: Also create message-count counters.
        """
        # The tracer is resolved before the meter, as in every other path.
        resolved_tracer = _get_tracer(tracer_provider)
        self._tracer = resolved_tracer

        resolved_meter = _get_meter(meter_provider, meter)
        self._meter = resolved_meter
        self._metrics = _MetricsContainer(resolved_meter, include_messages_counters)

        self._settings_provider_factory = settings_provider_factory

    def __call__(
        self,
        msg: Any | None,
        /,
        *,
        context: "ContextRepo",
    ) -> "BaseTelemetryMiddleware[PublishCommandType]":
        """Build the middleware instance responsible for ``msg``."""
        middleware_cls = BaseTelemetryMiddleware[PublishCommandType]
        return middleware_cls(
            msg,
            context=context,
            settings_provider_factory=self._settings_provider_factory,
            metrics_container=self._metrics,
            tracer=self._tracer,
        )

84 

85 

class _MetricsContainer:
    """Metric instruments shared by the telemetry middlewares.

    Publish and process duration histograms (unit ``"s"``) always exist.
    When ``include_messages_counters`` is true, per-direction message
    counters (unit ``"message"``) are created as well. Otherwise those slots
    stay unset and are never accessed.
    """

    __slots__ = (
        "include_messages_counters",
        "process_counter",
        "process_duration",
        "publish_counter",
        "publish_duration",
    )

    def __init__(self, meter: "Meter", include_messages_counters: bool) -> None:
        self.include_messages_counters = include_messages_counters

        make_histogram = meter.create_histogram
        self.publish_duration = make_histogram(
            name="messaging.publish.duration",
            unit="s",
            description="Measures the duration of publish operation.",
        )
        self.process_duration = make_histogram(
            name="messaging.process.duration",
            unit="s",
            description="Measures the duration of process operation.",
        )

        if not include_messages_counters:
            return

        make_counter = meter.create_counter
        self.process_counter = make_counter(
            name="messaging.process.messages",
            unit="message",
            description="Measures the number of processed messages.",
        )
        self.publish_counter = make_counter(
            name="messaging.publish.messages",
            unit="message",
            description="Measures the number of published messages.",
        )

    def observe_publish(
        self,
        attrs: dict[str, Any],
        duration: float,
        msg_count: int,
    ) -> None:
        """Record one publish: its duration and, optionally, its message count."""
        self.publish_duration.record(amount=duration, attributes=attrs)
        if self.include_messages_counters:
            self._count_messages(self.publish_counter, attrs, msg_count)

    def observe_consume(
        self,
        attrs: dict[str, Any],
        duration: float,
        msg_count: int,
    ) -> None:
        """Record one processing run: its duration and, optionally, its message count."""
        self.process_duration.record(amount=duration, attributes=attrs)
        if self.include_messages_counters:
            self._count_messages(self.process_counter, attrs, msg_count)

    @staticmethod
    def _count_messages(counter: Any, attrs: dict[str, Any], msg_count: int) -> None:
        """Add ``msg_count`` to ``counter`` with ``attrs`` minus the error type."""
        # Unlike the duration histograms, the counters are not split by error
        # type, so that key is dropped from a copy (the caller's dict is kept).
        counter_attrs = copy(attrs)
        counter_attrs.pop(ERROR_TYPE, None)
        counter.add(amount=msg_count, attributes=counter_attrs)

156 

157 

class BaseTelemetryMiddleware(BaseMiddleware[PublishCommandType]):
    """Per-message middleware emitting OpenTelemetry spans and messaging metrics.

    - ``consume_scope`` opens a CONSUMER PROCESS span for the incoming message.
      It exposes that span and the message baggage as the ``"span"`` and
      ``"baggage"`` context locals.
    - ``publish_scope`` wraps outgoing publishes in PRODUCER CREATE/PUBLISH
      spans and propagates baggage and trace context through the headers.
    - ``after_processed`` ends the PROCESS span and restores the context
      locals.

    When the settings provider factory returns ``None`` for the message, both
    scopes call straight through without telemetry.
    """

    def __init__(
        self,
        msg: Any | None,
        /,
        *,
        tracer: "Tracer",
        settings_provider_factory: Callable[
            [Any],
            Optional["TelemetrySettingsProvider[Any, PublishCommandType]"],
        ],
        metrics_container: _MetricsContainer,
        context: "ContextRepo",
    ) -> None:
        super().__init__(msg, context=context)

        self._tracer = tracer
        self._metrics = metrics_container
        # PROCESS span opened by consume_scope; ended in after_processed.
        self._current_span: Span | None = None
        # Parent context of the PROCESS span (propagated from the incoming
        # message, or a fresh CREATE span). publish_scope injects it into
        # outgoing headers.
        self._origin_context: Context | None = None
        # Context-local tokens set by consume_scope. A token can be used for a
        # reset only once, so they are reset exactly once, in after_processed.
        self._scope_tokens: list[tuple[str, Token[Any]]] = []
        # None disables telemetry for this message.
        self.__settings_provider = settings_provider_factory(msg)

    async def publish_scope(
        self,
        call_next: "AsyncFunc",
        msg: "PublishCommandType",
    ) -> Any:
        """Wrap an outgoing publish in CREATE/PUBLISH producer spans.

        Propagates the current baggage and trace context into ``msg.headers``.
        Records the publish duration (and optionally the message count).
        Exceptions from ``call_next`` are re-raised after their type is added
        to the metric attributes.
        """
        if (provider := self.__settings_provider) is None:
            return await call_next(msg)

        headers = msg.headers
        current_context = context.get_current()
        destination_name = provider.get_publish_destination_name(msg)

        # Forward the baggage of the message currently being processed, if any.
        current_baggage: Baggage | None = self.context.get_local("baggage")
        if current_baggage:
            headers.update(current_baggage.to_headers())

        trace_attributes = provider.get_publish_attrs_from_cmd(msg)
        metrics_attributes = {
            SpanAttributes.MESSAGING_SYSTEM: provider.messaging_system,
            SpanAttributes.MESSAGING_DESTINATION_NAME: destination_name,
        }

        # NOTE: if batch with single message?
        if (msg_count := len(msg.batch_bodies)) > 1:
            trace_attributes[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = msg_count
            # Flag the batch in the baggage; consumers check WITH_BATCH in
            # _is_batch_message.
            current_context = _BAGGAGE_PROPAGATOR.extract(headers, current_context)
            _BAGGAGE_PROPAGATOR.inject(
                headers,
                baggage.set_baggage(WITH_BATCH, True, context=current_context),
            )

        if self._current_span and self._current_span.is_recording():
            # This instance already ran consume_scope. The PUBLISH span is
            # parented on the PROCESS span, while the downstream headers carry
            # the origin context of the consumed message.
            current_context = trace.set_span_in_context(
                self._current_span,
                current_context,
            )
            _TRACE_PROPAGATOR.inject(headers, context=self._origin_context)

        else:
            create_span = self._tracer.start_span(
                name=_create_span_name(destination_name, MessageAction.CREATE),
                kind=trace.SpanKind.PRODUCER,
                attributes=trace_attributes,
            )
            current_context = trace.set_span_in_context(create_span)
            _TRACE_PROPAGATOR.inject(headers, context=current_context)
            create_span.end()

        start_time = time.perf_counter()

        try:
            with self._tracer.start_as_current_span(
                name=_create_span_name(destination_name, MessageAction.PUBLISH),
                kind=trace.SpanKind.PRODUCER,
                attributes=trace_attributes,
                context=current_context,
            ) as span:
                span.set_attribute(
                    SpanAttributes.MESSAGING_OPERATION,
                    MessageAction.PUBLISH,
                )
                msg.headers = headers
                result = await call_next(msg)

        except Exception as e:
            metrics_attributes[ERROR_TYPE] = type(e).__name__
            raise

        finally:
            duration = time.perf_counter() - start_time
            self._metrics.observe_publish(metrics_attributes, duration, msg_count)

        # Context locals from consume_scope are deliberately NOT reset here.
        # Several publishes may run on this instance, each needs the
        # "baggage" local, and a token cannot be reused. after_processed
        # resets them once.
        return result

    async def consume_scope(
        self,
        call_next: "AsyncFuncAny",
        msg: "StreamMessage[Any]",
    ) -> Any:
        """Wrap message processing in a CONSUMER PROCESS span and record metrics.

        The parent is the trace context propagated in ``msg.headers``. For
        batch messages an empty context is used instead, so a CREATE span
        that links the individual messages becomes the parent. A CREATE span
        is also created (and immediately ended) whenever no parent context is
        available. The PROCESS span is not ended on exit; ``after_processed``
        ends it.
        """
        if (provider := self.__settings_provider) is None:
            return await call_next(msg)

        if _is_batch_message(msg):
            links = _get_msg_links(msg)
            current_context = Context()
        else:
            links = None
            current_context = _TRACE_PROPAGATOR.extract(msg.headers)

        destination_name = provider.get_consume_destination_name(msg)
        trace_attributes = provider.get_consume_attrs_from_message(msg)
        metrics_attributes = {
            SpanAttributes.MESSAGING_SYSTEM: provider.messaging_system,
            MESSAGING_DESTINATION_PUBLISH_NAME: destination_name,
        }

        if not len(current_context):
            create_span = self._tracer.start_span(
                name=_create_span_name(destination_name, MessageAction.CREATE),
                kind=trace.SpanKind.CONSUMER,
                attributes=trace_attributes,
                links=links,
            )
            current_context = trace.set_span_in_context(create_span)
            create_span.end()

        self._origin_context = current_context
        start_time = time.perf_counter()

        try:
            with self._tracer.start_as_current_span(
                name=_create_span_name(destination_name, MessageAction.PROCESS),
                kind=trace.SpanKind.CONSUMER,
                context=current_context,
                attributes=trace_attributes,
                end_on_exit=False,
            ) as span:
                span.set_attribute(
                    SpanAttributes.MESSAGING_OPERATION,
                    MessageAction.PROCESS,
                )
                self._current_span = span

                self._scope_tokens.append((
                    "span",
                    self.context.set_local("span", span),
                ))
                self._scope_tokens.append(
                    (
                        "baggage",
                        self.context.set_local("baggage", Baggage.from_msg(msg)),
                    ),
                )

                new_context = trace.set_span_in_context(span, current_context)
                token = context.attach(new_context)
                try:
                    result = await call_next(msg)
                finally:
                    # Keep attach/detach balanced even when the handler raises.
                    context.detach(token)

        except Exception as e:
            metrics_attributes[ERROR_TYPE] = type(e).__name__
            raise

        finally:
            duration = time.perf_counter() - start_time
            msg_count = trace_attributes.get(
                SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT,
                1,
            )
            self._metrics.observe_consume(metrics_attributes, duration, msg_count)

        return result

    async def after_processed(
        self,
        exc_type: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: Optional["TracebackType"] = None,
    ) -> bool | None:
        """End the PROCESS span and restore the context locals set while consuming.

        Returns ``False``. Presumably the base class forwards this value from
        ``__aexit__``, so exceptions are not suppressed; confirm against
        ``BaseMiddleware``.
        """
        if self._current_span and self._current_span.is_recording():
            self._current_span.end()

        # Reset in reverse order of setting. Popping guarantees each token is
        # used at most once, even if this hook runs again.
        while self._scope_tokens:
            key, token = self._scope_tokens.pop()
            self.context.reset_local(key, token)

        return False

346 

347 

def _get_meter(
    meter_provider: Optional["MeterProvider"] = None,
    meter: Optional["Meter"] = None,
) -> "Meter":
    """Return ``meter`` when supplied, otherwise create one via ``metrics.get_meter``.

    The fallback meter is named after this module and tagged with
    ``OTEL_SCHEMA``. ``meter_provider`` is only consulted in that case.
    """
    if meter is not None:
        return meter
    return metrics.get_meter(
        __name__,
        meter_provider=meter_provider,
        schema_url=OTEL_SCHEMA,
    )

359 

360 

def _get_tracer(tracer_provider: Optional["TracerProvider"] = None) -> "Tracer":
    """Obtain this package's instrumentation tracer.

    The tracer is identified by ``INSTRUMENTING_MODULE_NAME`` and
    ``INSTRUMENTING_LIBRARY_VERSION`` and tagged with ``OTEL_SCHEMA``.
    ``tracer_provider`` is forwarded unchanged.
    """
    tracer_options = {
        "instrumenting_module_name": INSTRUMENTING_MODULE_NAME,
        "instrumenting_library_version": INSTRUMENTING_LIBRARY_VERSION,
        "tracer_provider": tracer_provider,
        "schema_url": OTEL_SCHEMA,
    }
    return trace.get_tracer(**tracer_options)

368 

369 

def _create_span_name(destination: str, action: str) -> str:
    """Build a span name of the form ``"<destination> <action>"``."""
    return "{} {}".format(destination, action)

372 

373 

def _is_batch_message(msg: "StreamMessage[Any]") -> bool:
    """Tell whether ``msg`` should be treated as a batch.

    True if the message has per-message ``batch_headers``, or if its
    propagated baggage carries the ``WITH_BATCH`` flag set by the publisher.
    """
    propagated = _BAGGAGE_PROPAGATOR.extract(msg.headers)
    batch_flag = baggage.get_baggage(WITH_BATCH, propagated)
    return bool(msg.batch_headers) or bool(batch_flag)

380 

381 

def _get_msg_links(msg: "StreamMessage[Any]") -> list[Link]:
    """Collect span links for the message(s) contained in ``msg``.

    Without ``batch_headers``, the trace context in ``msg.headers`` (if any)
    becomes the single link. Otherwise there is at most one link per distinct
    ``correlation_id``:

    - Entries lacking a correlation id are skipped.
    - Entries lacking a trace context are skipped.
    - The latest traced entry for an id wins.
    - Its link attributes come from ``_get_link_attributes`` applied to the
      number of occurrences of that id seen so far.
    """
    batch_headers = msg.batch_headers
    if not batch_headers:
        single_span = _get_span_from_headers(msg.headers)
        if single_span is None:
            return []
        return [Link(single_span.get_span_context())]

    occurrences: dict[str, int] = defaultdict(int)
    links_by_id: dict[str, Link] = {}

    for entry in batch_headers:
        correlation_id = entry.get("correlation_id")
        if correlation_id is None:
            continue

        occurrences[correlation_id] += 1

        entry_span = _get_span_from_headers(entry)
        if entry_span is None:
            continue

        link_attrs = _get_link_attributes(occurrences[correlation_id])
        links_by_id[correlation_id] = Link(
            entry_span.get_span_context(),
            attributes=link_attrs,
        )

    return list(links_by_id.values())

408 

409 

def _get_span_from_headers(headers: dict[str, Any]) -> Span | None:
    """Return the span carried by the trace context in ``headers``, or ``None``.

    The first value of the extracted context is returned. Presumably this is
    the span entry written by the trace propagator; confirm if other
    propagators are ever chained here.
    """
    extracted = _TRACE_PROPAGATOR.extract(headers)
    for first_value in extracted.values():
        return cast("Span | None", first_value)
    return None

419 

420 

def _get_link_attributes(message_count: int) -> "Attributes":
    """Span-link attributes for a correlation id seen ``message_count`` times.

    A count of one or less yields no attributes. Larger counts are reported
    as the messaging batch message count.
    """
    return (
        {}
        if message_count <= 1
        else {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count}
    )