Coverage for faststream / confluent / subscriber / usecase.py: 98%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import logging 

2from abc import abstractmethod 

3from collections.abc import AsyncIterator, Sequence 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Optional, 

8 cast, 

9) 

10 

11import anyio 

12from confluent_kafka import KafkaException, Message 

13from typing_extensions import override 

14 

15from faststream._internal.endpoint.subscriber import SubscriberUsecase 

16from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin, TasksMixin 

17from faststream._internal.endpoint.utils import process_msg 

18from faststream._internal.types import MsgType 

19from faststream.confluent.parser import AsyncConfluentParser 

20from faststream.confluent.publisher.fake import KafkaFakePublisher 

21from faststream.confluent.schemas import TopicPartition 

22 

23if TYPE_CHECKING: 

24 from faststream._internal.endpoint.publisher import PublisherProto 

25 from faststream._internal.endpoint.subscriber import SubscriberSpecification 

26 from faststream._internal.endpoint.subscriber.call_item import CallsCollection 

27 from faststream.confluent.configs import KafkaBrokerConfig 

28 from faststream.confluent.helpers.client import AsyncConfluentConsumer 

29 from faststream.confluent.message import KafkaMessage 

30 from faststream.message import StreamMessage 

31 

32 from .config import KafkaSubscriberConfig 

33 

34 

class LogicSubscriber(TasksMixin, SubscriberUsecase[MsgType]):
    """A class to handle logic for consuming messages from Kafka.

    Owns the consumer lifecycle (``start`` / ``stop``) and the polling loop
    (``_consume``). Subclasses implement ``get_msg`` and assign ``parser``
    before calling this initializer.
    """

    _outer_config: "KafkaBrokerConfig"

    group_id: str | None

    consumer: Optional["AsyncConfluentConsumer"]
    parser: AsyncConfluentParser

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[MsgType]",
    ) -> None:
        """Store subscription settings taken from ``config``.

        The consumer itself is not created here; it stays ``None`` until
        ``start`` is awaited.
        """
        super().__init__(config, specification, calls)

        self.__connection_data = config.connection_data

        self.group_id = config.group_id

        # Raw (unprefixed) values; the `topics` / `partitions` properties add
        # the broker prefix on access.
        self._topics = config.topics
        self._partitions = config.partitions

        self.consumer = None
        self.polling_interval = config.polling_interval

    @property
    def client_id(self) -> str | None:
        """Client id taken from the broker-level configuration."""
        return self._outer_config.client_id

    @property
    def topics(self) -> list[str]:
        """Subscribed topic names with the broker prefix applied."""
        return [f"{self._outer_config.prefix}{t}" for t in self._topics]

    @property
    def partitions(self) -> list[TopicPartition]:
        """Subscribed partitions with the broker prefix applied."""
        return [p.add_prefix(self._outer_config.prefix) for p in self._partitions]

    @override
    async def start(self) -> None:
        """Start the consumer."""
        await super().start()
        self.consumer = consumer = self._outer_config.builder(
            *self.topics,
            partitions=self.partitions,
            group_id=self.group_id,
            client_id=self.client_id,
            **self.__connection_data,
        )
        self.parser._setup(consumer)
        await consumer.start()

        self._post_start()

        # The background polling loop is only scheduled when handlers are
        # registered; otherwise messages are pulled via `get_one` / iteration.
        if self.calls:
            self.add_task(self._consume)

    async def stop(self) -> None:
        """Stop the subscriber and release the consumer, if one was started."""
        await super().stop()

        if self.consumer is not None:
            await self.consumer.stop()
            self.consumer = None

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
    ) -> "KafkaMessage | None":
        """Fetch a single message and run it through middlewares, parser and decoder.

        Args:
            timeout: Value forwarded to the consumer's ``getone`` call.

        Returns:
            The result of ``process_msg`` for the fetched raw message
            (presumably ``None`` when nothing arrived in time -- confirm
            against ``process_msg``).

        Raises:
            AssertionError: If the subscriber is not started or has
                registered handlers.
        """
        assert self.consumer, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        raw_message = await self.consumer.getone(timeout=timeout)

        context = self._outer_config.fd_config.context

        async_parser, async_decoder = self._get_parser_and_decoder()

        return await process_msg(  # type: ignore[return-value]
            msg=raw_message,
            middlewares=(
                m(raw_message, context=context) for m in self._broker_middlewares
            ),
            parser=async_parser,
            decoder=async_decoder,
        )

    @override
    async def __aiter__(self) -> AsyncIterator["KafkaMessage"]:  # type: ignore[override]
        """Yield processed messages indefinitely.

        The consumer is polled with a fixed 5 second timeout; polls that
        return ``None`` are skipped.

        Raises:
            AssertionError: If the subscriber is not started or has
                registered handlers.
        """
        assert self.consumer, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use iterator if subscriber has registered handlers."
        )

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        timeout = 5.0
        while True:
            raw_message = await self.consumer.getone(timeout=timeout)

            if raw_message is None:
                continue

            yield cast(
                "KafkaMessage",
                await process_msg(
                    msg=raw_message,
                    middlewares=(
                        m(raw_message, context=context) for m in self._broker_middlewares
                    ),
                    parser=async_parser,
                    decoder=async_decoder,
                ),
            )

    def _make_response_publisher(
        self,
        message: "StreamMessage[Any]",
    ) -> Sequence["PublisherProto"]:
        """Return a one-element tuple with a publisher targeting ``message.reply_to``."""
        return (
            KafkaFakePublisher(
                self._outer_config.producer,
                topic=message.reply_to,
            ),
        )

    async def consume_one(self, msg: MsgType) -> None:
        """Process one fetched message by delegating to ``consume``."""
        await self.consume(msg)

    @abstractmethod
    async def get_msg(self) -> MsgType | None:
        """Fetch the next message (or batch); implemented by subclasses."""
        raise NotImplementedError

    async def _consume(self) -> None:
        """Poll ``get_msg`` while the subscriber is running and dispatch results.

        A ``KafkaException`` during fetch is logged and followed by a
        5 second pause before the next attempt.
        """
        assert self.consumer, "You should start subscriber at first."

        # Tracks whether the previous fetch succeeded; it is only flipped
        # here and does not currently drive any other behavior.
        connected = True
        while self.running:
            try:
                msg = await self.get_msg()

            except KafkaException as e:  # pragma: no cover  # noqa: PERF203
                self._log(
                    logging.ERROR,
                    message="Message fetch error",
                    exc_info=e,
                )

                if connected:
                    connected = False

                await anyio.sleep(5)

            else:
                if not connected:  # pragma: no cover
                    connected = True

                if msg is not None:
                    await self.consume_one(msg)

    @property
    def topic_names(self) -> list[str]:
        """Human-readable names of what this subscriber listens to.

        Returns the prefixed topics, or ``"<topic>-<partition>"`` labels when
        no topics are configured. ``topics`` and ``partitions`` already carry
        the broker prefix, so it is not applied a second time here.
        """
        if topics := self.topics:
            return topics
        return [f"{p.topic}-{p.partition}" for p in self.partitions]

    @staticmethod
    def build_log_context(
        message: Optional["StreamMessage[Any]"],
        topic: str,
        group_id: str | None = None,
    ) -> dict[str, str]:
        """Build the logging context mapping.

        Args:
            message: Message being handled, or ``None``.
            topic: Topic label to log.
            group_id: Consumer group; logged as ``""`` when falsy.

        Returns:
            A dict with ``topic``, ``group_id`` and ``message_id`` keys;
            ``message_id`` is ``""`` when ``message`` has no such attribute.
        """
        return {
            "topic": topic,
            "group_id": group_id or "",
            "message_id": getattr(message, "message_id", ""),
        }

217 

218 

class DefaultSubscriber(LogicSubscriber[Message]):
    """Subscriber that fetches Kafka messages one at a time."""

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Message]",
    ) -> None:
        """Create the single-message parser and register its callbacks.

        The parser's ``decode_message`` / ``parse_message`` methods are
        placed on ``config`` before it is handed to the base initializer.
        """
        manual_ack = not config.ack_first
        self.parser = AsyncConfluentParser(is_manual=manual_ack)
        config.decoder = self.parser.decode_message
        config.parser = self.parser.parse_message
        super().__init__(config, specification, calls)

    async def get_msg(self) -> Optional["Message"]:
        """Poll the consumer for one message, waiting up to ``polling_interval``."""
        consumer = self.consumer
        assert consumer, "You should setup subscriber at first."
        return await consumer.getone(timeout=self.polling_interval)

    def get_log_context(
        self,
        message: Optional["StreamMessage[Message]"],
    ) -> dict[str, str]:
        """Build the log context for ``message``.

        Without a message the topic label is the joined ``topic_names``;
        with one it is the raw message's topic, falling back to the joined
        ``topics`` when that is falsy.
        """
        if message is not None:
            topic = message.raw_message.topic() or ",".join(self.topics)
        else:
            topic = ",".join(self.topic_names)

        return self.build_log_context(
            message=message,
            topic=topic,
            group_id=self.group_id,
        )

249 

250 

class ConcurrentDefaultSubscriber(ConcurrentMixin["Message"], DefaultSubscriber):
    """Single-message subscriber that hands fetched messages to ``ConcurrentMixin``."""

    async def start(self) -> None:
        """Start the underlying subscriber, then launch the consume task."""
        await super().start()
        self.start_consume_task()

    async def consume_one(self, msg: "Message") -> None:
        """Enqueue ``msg`` via ``_put_msg`` instead of handling it inline."""
        enqueue = self._put_msg
        await enqueue(msg)

258 

259 

class BatchSubscriber(LogicSubscriber[tuple[Message, ...]]):
    """Subscriber that fetches Kafka messages in batches."""

    def __init__(
        self,
        config: "KafkaSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[tuple[Message, ...]]",
        max_records: int | None,
    ) -> None:
        """Create the batch parser, register its callbacks and keep ``max_records``.

        The parser's ``decode_batch`` / ``parse_batch`` methods are placed on
        ``config`` before it is handed to the base initializer.
        """
        manual_ack = not config.ack_first
        self.parser = AsyncConfluentParser(is_manual=manual_ack)
        config.decoder = self.parser.decode_batch
        config.parser = self.parser.parse_batch
        super().__init__(config, specification, calls)

        self.max_records = max_records

    async def get_msg(self) -> tuple["Message", ...] | None:
        """Poll the consumer for a batch; a falsy (empty) result becomes ``None``."""
        consumer = self.consumer
        assert consumer, "You should setup subscriber at first."
        batch = await consumer.getmany(
            timeout=self.polling_interval,
            max_records=self.max_records,
        )
        if not batch:
            return None
        return batch

    def get_log_context(
        self,
        message: Optional["StreamMessage[tuple[Message, ...]]"],
    ) -> dict[str, str]:
        """Build the log context for ``message``.

        The topic label is the topic of the first raw message in the batch;
        when there is no message, or that topic is falsy, the joined
        ``topic_names`` are used instead.
        """
        topic = message.raw_message[0].topic() if message is not None else None
        if not topic:
            topic = ",".join(self.topic_names)

        return self.build_log_context(
            message=message,
            topic=topic,
            group_id=self.group_id,
        )