Coverage for faststream / kafka / subscriber / usecase.py: 92%

153 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import logging 

2from abc import abstractmethod 

3from collections.abc import AsyncIterator, Callable, Sequence 

4from itertools import chain 

5from typing import TYPE_CHECKING, Any, Optional, cast 

6 

7import anyio 

8from aiokafka import ConsumerRecord, TopicPartition 

9from aiokafka.errors import ConsumerStoppedError, KafkaError, UnsupportedCodecError 

10from typing_extensions import override 

11 

12from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin, TasksMixin 

13from faststream._internal.endpoint.subscriber.usecase import SubscriberUsecase 

14from faststream._internal.endpoint.utils import process_msg 

15from faststream._internal.types import MsgType 

16from faststream._internal.utils.path import compile_path 

17from faststream.kafka.helpers import make_logging_listener 

18from faststream.kafka.message import KafkaAckableMessage, KafkaMessage, KafkaRawMessage 

19from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser 

20from faststream.kafka.publisher.fake import KafkaFakePublisher 

21 

22if TYPE_CHECKING: 

23 from aiokafka import AIOKafkaConsumer 

24 

25 from faststream._internal.endpoint.publisher import PublisherProto 

26 from faststream._internal.endpoint.subscriber import SubscriberSpecification 

27 from faststream._internal.endpoint.subscriber.call_item import CallsCollection 

28 from faststream.kafka.configs import KafkaBrokerConfig 

29 from faststream.message import StreamMessage 

30 

31 from .config import KafkaSubscriberConfig 

32 

33 

class LogicSubscriber(TasksMixin, SubscriberUsecase[MsgType]):
    """A class to handle logic for consuming messages from Kafka.

    The class builds one consumer in ``start()``, subscribes it to topics or a
    pattern (or assigns explicit partitions), and, when handlers are
    registered, runs a background consume loop. Subclasses implement
    ``get_msg`` and set ``self.parser`` before ``start()`` is called.
    """

    # Active consumer: ``None`` until ``start()`` builds one, and again after ``stop()``.
    consumer: Optional["AIOKafkaConsumer"]

    # Declared here only; nothing in this class assigns it.
    batch: bool
    # Assigned by subclasses; ``start()`` calls ``parser._setup(consumer)`` on it.
    parser: AioKafkaParser

    _outer_config: "KafkaBrokerConfig"

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[MsgType]",
    ) -> None:
        """Copy the Kafka-specific settings from ``config`` onto the instance.

        Args:
            config: Source of topics, partitions, group id, pattern, rebalance
                listener and extra connection arguments.
            specification: Forwarded unchanged to the base class.
            calls: Forwarded unchanged to the base class.
        """
        super().__init__(config, specification, calls)

        # Raw (unprefixed) values; the public properties add the broker prefix.
        self._topics = config.topics
        self._partitions = config.partitions
        self.group_id = config.group_id

        self._pattern = config.pattern
        self._listener = config.listener
        self._connection_args = config.connection_args

        self.consumer = None

    @property
    def pattern(self) -> str | None:
        """Return the prefixed pattern, or the raw falsy value when none is set."""
        if not self._pattern:
            return self._pattern
        return f"{self._outer_config.prefix}{self._pattern}"

    @property
    def topics(self) -> list[str]:
        """Return the configured topics, each with the broker prefix prepended."""
        return [f"{self._outer_config.prefix}{t}" for t in self._topics]

    @property
    def partitions(self) -> list[TopicPartition]:
        """Return the configured partitions with the broker prefix on each topic."""
        return [
            TopicPartition(
                topic=f"{self._outer_config.prefix}{p.topic}",
                partition=p.partition,
            )
            for p in self._partitions
        ]

    @property
    def builder(self) -> Callable[..., "AIOKafkaConsumer"]:
        """Return the consumer factory held by the broker config."""
        return self._outer_config.builder

    @property
    def client_id(self) -> str | None:
        """Return the client id held by the broker config."""
        return self._outer_config.client_id

    async def start(self) -> None:
        """Start the consumer.

        Builds the consumer, hands it to the parser, subscribes (topics or
        pattern) or assigns (explicit partitions), starts it, and schedules the
        consume loop when handlers are registered.
        """
        await super().start()

        self.consumer = consumer = self.builder(
            group_id=self.group_id,
            client_id=self.client_id,
            **self._connection_args,
        )

        self.parser._setup(consumer)

        if self.topics or self.pattern:
            consumer.subscribe(
                topics=self.topics,
                pattern=self.pattern,
                listener=make_logging_listener(
                    consumer=consumer,
                    logger=self._outer_config.logger.logger.logger,
                    log_extra=self.get_log_context(None),
                    listener=self._listener,
                ),
            )

        elif self.partitions:
            # Explicit partitions are only used when no topics or pattern are set.
            consumer.assign(partitions=self.partitions)

        await consumer.start()

        self._post_start()

        # Without handlers the consumer is left idle for `get_one()` / iteration.
        if self.calls:
            self.add_task(self._run_consume_loop, (self.consumer,))

    async def stop(self) -> None:
        """Stop the base class first, then stop and drop the consumer if present."""
        await super().stop()

        if self.consumer is not None:
            await self.consumer.stop()
            self.consumer = None

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
    ) -> "KafkaMessage | None":
        """Fetch and process at most one record from the started consumer.

        Args:
            timeout: How long to wait, in seconds (converted to milliseconds
                for the consumer call).

        Returns:
            The processed message, or ``None`` when nothing was fetched.

        Raises:
            AssertionError: If handlers are registered or the subscriber has
                not been started.
        """
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        assert self.consumer, "You should start subscriber at first."

        raw_messages = await self.consumer.getmany(
            timeout_ms=timeout * 1000,
            max_records=1,
        )

        if not raw_messages:
            return None

        # Strict unpacking: exactly one partition holding exactly one record.
        ((raw_message,),) = raw_messages.values()

        context = self._outer_config.fd_config.context

        async_parser, async_decoder = self._get_parser_and_decoder()

        msg: KafkaMessage | None = await process_msg(  # type: ignore[assignment]
            msg=raw_message,
            middlewares=(
                m(raw_message, context=context) for m in self._broker_middlewares
            ),
            parser=async_parser,
            decoder=async_decoder,
        )
        return msg

    @override
    async def __aiter__(self) -> AsyncIterator["KafkaMessage"]:  # type: ignore[override]
        """Yield processed messages by iterating the started consumer.

        Raises:
            AssertionError: If the subscriber has not been started or handlers
                are registered.
        """
        assert self.consumer, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use iterator if subscriber has registered handlers."
        )

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        async for raw_message in self.consumer:
            msg: KafkaMessage = await process_msg(  # type: ignore[assignment]
                msg=raw_message,
                middlewares=(
                    m(raw_message, context=context) for m in self._broker_middlewares
                ),
                parser=async_parser,
                decoder=async_decoder,
            )
            yield msg

    def _make_response_publisher(
        self,
        message: "StreamMessage[Any]",
    ) -> Sequence["PublisherProto"]:
        """Return a one-element tuple with a publisher targeting ``message.reply_to``."""
        return (
            KafkaFakePublisher(
                self._outer_config.producer,
                topic=message.reply_to,
            ),
        )

    @abstractmethod
    async def get_msg(self, consumer: "AIOKafkaConsumer") -> MsgType:
        """Fetch the next raw message (or batch) from ``consumer``."""
        raise NotImplementedError

    async def _run_consume_loop(self, consumer: "AIOKafkaConsumer") -> None:
        """Poll ``consumer`` while the subscriber is running and dispatch results.

        Error handling per iteration:
            * ``UnsupportedCodecError``: log and sleep 15 seconds.
            * any other ``KafkaError``: log and sleep 5 seconds.
            * ``ConsumerStoppedError``: leave the loop.

        A successfully fetched, truthy message is passed to ``consume_one``.
        """
        assert consumer, "You should start subscriber at first."

        while self.running:
            try:
                msg = await self.get_msg(consumer)

            # Checked before the generic KafkaError handler so it gets its own
            # message and a longer back-off.
            except UnsupportedCodecError as e:  # noqa: PERF203
                self._log(
                    logging.ERROR,
                    "There is no suitable compression library available. Please refer to the Kafka "
                    "documentation for more information - "
                    "https://aiokafka.readthedocs.io/en/stable/#installation",
                    exc_info=e,
                )
                await anyio.sleep(15)

            except KafkaError as e:
                self._log(logging.ERROR, "Kafka error occurred", exc_info=e)
                await anyio.sleep(5)

            except ConsumerStoppedError:
                return

            else:
                # Falsy results (e.g. an empty batch) are skipped.
                if msg:
                    await self.consume_one(msg)

    async def consume_one(self, msg: MsgType) -> None:
        """Process one fetched message; subclasses may override the dispatch."""
        await self.consume(msg)

    @property
    def topic_names(self) -> list[str]:
        """Return display names: the pattern, else topics, else ``topic-partition`` pairs."""
        if self.pattern:
            topics = [self.pattern]

        elif self.topics:
            topics = self.topics

        else:
            topics = [f"{p.topic}-{p.partition}" for p in self.partitions]

        return topics

    @staticmethod
    def build_log_context(
        message: Optional["StreamMessage[Any]"],
        topic: str,
        group_id: str | None = None,
    ) -> dict[str, str]:
        """Build the logging context dict.

        Args:
            message: Message whose ``message_id`` is reported; ``""`` is used
                when the attribute is missing (including when ``message`` is
                ``None``).
            topic: Topic name to report.
            group_id: Consumer group; a falsy value is reported as ``""``.
        """
        return {
            "topic": topic,
            "group_id": group_id or "",
            "message_id": getattr(message, "message_id", ""),
        }

266 

267 

class DefaultSubscriber(LogicSubscriber["ConsumerRecord"]):
    """Subscriber that fetches Kafka records one at a time."""

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[ConsumerRecord]",
    ) -> None:
        """Compile the topic pattern (if any) and wire a single-record parser into ``config``."""
        topic_regex = None
        if config.pattern:
            # The compiled regex goes to the parser; the second value returned
            # by compile_path is stored back on the config as the pattern.
            topic_regex, config.pattern = compile_path(
                config.pattern,
                replace_symbol=".*",
                patch_regex=lambda x: x.replace(r"\*", ".*"),
            )

        if config.ack_first:
            message_cls = KafkaMessage
        else:
            message_cls = KafkaAckableMessage

        record_parser = AioKafkaParser(msg_class=message_cls, regex=topic_regex)
        self.parser = record_parser

        # The config must carry parser/decoder before the base initializer runs.
        config.parser = record_parser.parse_message
        config.decoder = record_parser.decode_message
        super().__init__(config, specification, calls)

    async def get_msg(self, consumer: "AIOKafkaConsumer") -> "ConsumerRecord":
        """Await and return the next single record from ``consumer``."""
        assert consumer, "You should setup subscriber at first."
        record = await consumer.getone()
        return record

    def get_log_context(
        self,
        message: Optional["StreamMessage[ConsumerRecord]"],
    ) -> dict[str, str]:
        """Build the log context, using the message's topic when a message is given.

        With no message, the topic field is the comma-joined ``topic_names``.
        """
        topic = (
            ",".join(self.topic_names)
            if message is None
            else message.raw_message.topic
        )
        return self.build_log_context(
            message=message,
            topic=topic,
            group_id=self.group_id,
        )

312 

313 

class BatchSubscriber(LogicSubscriber[tuple["ConsumerRecord", ...]]):
    """Subscriber that fetches Kafka records in batches."""

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[tuple[ConsumerRecord, ...]]",
        batch_timeout_ms: int,
        max_records: int | None,
    ) -> None:
        """Compile the topic pattern (if any), wire a batch parser, store batch limits.

        Args:
            config: Subscriber config; its ``pattern``, ``decoder`` and
                ``parser`` fields are overwritten here.
            specification: Forwarded unchanged to the base class.
            calls: Forwarded unchanged to the base class.
            batch_timeout_ms: Passed as ``timeout_ms`` to every ``getmany`` call.
            max_records: Passed as ``max_records`` to every ``getmany`` call.
        """
        topic_regex = None
        if config.pattern:
            # The compiled regex goes to the parser; the second value returned
            # by compile_path is stored back on the config as the pattern.
            topic_regex, config.pattern = compile_path(
                config.pattern,
                replace_symbol=".*",
                patch_regex=lambda x: x.replace(r"\*", ".*"),
            )

        if config.ack_first:
            message_cls = KafkaMessage
        else:
            message_cls = KafkaAckableMessage

        batch_parser = AioKafkaBatchParser(msg_class=message_cls, regex=topic_regex)
        self.parser = batch_parser

        # The config must carry decoder/parser before the base initializer runs.
        config.decoder = batch_parser.decode_batch
        config.parser = batch_parser.parse_batch
        super().__init__(config, specification, calls)

        self.batch_timeout_ms = batch_timeout_ms
        self.max_records = max_records

    async def get_msg(
        self,
        consumer: "AIOKafkaConsumer",
    ) -> tuple["ConsumerRecord", ...]:
        """Fetch one batch and flatten it into a single tuple of records.

        Returns an empty tuple, after sleeping for the batch timeout, when the
        fetch produced nothing.
        """
        assert consumer, "You should setup subscriber at first."

        fetched = await consumer.getmany(
            timeout_ms=self.batch_timeout_ms,
            max_records=self.max_records,
        )

        if not fetched:  # pragma: no cover
            await anyio.sleep(self.batch_timeout_ms / 1000)
            return ()

        # `fetched` maps each partition to its records; concatenate them in
        # mapping order.
        per_partition = fetched.values()
        return tuple(chain.from_iterable(per_partition))

    def get_log_context(
        self,
        message: Optional["StreamMessage[tuple[ConsumerRecord, ...]]"],
    ) -> dict[str, str]:
        """Build the log context, using the first record's topic when a message is given.

        With no message, the topic field is the comma-joined ``topic_names``.
        """
        topic = (
            ",".join(self.topic_names)
            if message is None
            else message.raw_message[0].topic
        )
        return self.build_log_context(
            message=message,
            topic=topic,
            group_id=self.group_id,
        )

376 

377 

class ConcurrentDefaultSubscriber(ConcurrentMixin["ConsumerRecord"], DefaultSubscriber):
    """Single-record subscriber that hands fetched records to the concurrent mixin."""

    async def start(self) -> None:
        """Run the inherited start-up, then launch the mixin's consume task."""
        await super().start()
        self.start_consume_task()

    async def consume_one(self, msg: "ConsumerRecord") -> None:
        """Pass the record to ``_put_msg`` instead of consuming it inline."""
        await self._put_msg(msg)

385 

386 

class ConcurrentBetweenPartitionsSubscriber(DefaultSubscriber):
    """Single-record subscriber that runs one consume loop per consumer in a subgroup."""

    # Consumers created by ``start()``; reset to an empty list by ``stop()``.
    consumer_subgroup: list["AIOKafkaConsumer"]

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[ConsumerRecord]",
        max_workers: int,
    ) -> None:
        """Initialise the subscriber and remember how many consumers to build.

        Args:
            config: Forwarded unchanged to the base class.
            specification: Forwarded unchanged to the base class.
            calls: Forwarded unchanged to the base class.
            max_workers: Number of consumers built when handlers are registered.
        """
        super().__init__(config, specification, calls)

        self.max_workers = max_workers
        self.consumer_subgroup = []

    async def start(self) -> None:
        """Start the consumer subgroup."""
        # Deliberately skips LogicSubscriber.start() in the MRO: this class
        # builds and starts its own consumers below.
        await super(LogicSubscriber, self).start()

        if self.calls:
            members = [
                self.builder(
                    group_id=self.group_id,
                    client_id=self.client_id,
                    **self._connection_args,
                )
                for _ in range(self.max_workers)
            ]

        else:
            # We should create single consumer to support
            # `get_one()` and `__aiter__` methods
            self.consumer = self.builder(
                group_id=self.group_id,
                client_id=self.client_id,
                **self._connection_args,
            )
            members = [self.consumer]

        self.consumer_subgroup = members

        # Subscribers starting should be called concurrently
        # to balance them correctly
        async with anyio.create_task_group() as tg:
            for member in self.consumer_subgroup:
                rebalance_listener = make_logging_listener(
                    consumer=member,
                    logger=self._outer_config.logger.logger.logger,
                    log_extra=self.get_log_context(None),
                    listener=self._listener,
                )
                member.subscribe(
                    topics=self.topics,
                    listener=rebalance_listener,
                )

                tg.start_soon(member.start)

        self._post_start()

        # One consume loop per consumer, only when handlers are registered.
        if self.calls:
            for member in self.consumer_subgroup:
                self.add_task(self._run_consume_loop, (member,))

    async def stop(self) -> None:
        """Stop every consumer in the subgroup concurrently, then run the inherited stop."""
        subgroup = self.consumer_subgroup
        if subgroup:
            async with anyio.create_task_group() as tg:
                for member in subgroup:
                    tg.start_soon(member.stop)

            self.consumer_subgroup = []

        await super().stop()

    async def get_msg(self, consumer: "AIOKafkaConsumer") -> "KafkaRawMessage":
        """Fetch the next record and attach the consumer that produced it."""
        assert consumer, "You should setup subscriber at first."
        record = await consumer.getone()
        # Record which consumer in the subgroup the message came from.
        record.consumer = consumer
        return cast("KafkaRawMessage", record)