Coverage for faststream / kafka / testing.py: 88%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import re 

2from collections.abc import Callable, Generator, Iterable, Iterator 

3from contextlib import ExitStack, contextmanager 

4from datetime import datetime, timezone 

5from typing import TYPE_CHECKING, Any, Optional, cast 

6from unittest.mock import AsyncMock, MagicMock 

7 

8import anyio 

9from aiokafka import ConsumerRecord 

10from typing_extensions import override 

11 

12from faststream._internal.endpoint.utils import ParserComposition 

13from faststream._internal.testing.broker import TestBroker, change_producer 

14from faststream.exceptions import SubscriberNotFound 

15from faststream.kafka import TopicPartition 

16from faststream.kafka.broker import KafkaBroker 

17from faststream.kafka.message import KafkaMessage 

18from faststream.kafka.parser import AioKafkaParser 

19from faststream.kafka.publisher.producer import AioKafkaFastProducer 

20from faststream.kafka.publisher.usecase import BatchPublisher 

21from faststream.kafka.subscriber.usecase import BatchSubscriber 

22from faststream.message import encode_message, gen_cor_id 

23 

24if TYPE_CHECKING: 

25 from fast_depends.library.serializer import SerializerProto 

26 

27 from faststream._internal.basic_types import SendableMessage 

28 from faststream.kafka.publisher.usecase import LogicPublisher 

29 from faststream.kafka.response import KafkaPublishCommand 

30 from faststream.kafka.subscriber.usecase import LogicSubscriber 

31 

32__all__ = ("TestKafkaBroker",) 

33 

34 

class TestKafkaBroker(TestBroker[KafkaBroker]):
    """A class to test Kafka brokers.

    Swaps the broker's producer for :class:`FakeProducer` and its admin
    client / consumer builder for in-memory fakes, so tests run without a
    real Kafka cluster.
    """

    @contextmanager
    def _patch_producer(self, broker: KafkaBroker) -> Iterator[None]:
        """Install a :class:`FakeProducer` on ``broker`` for the duration of the context."""
        fake_producer = FakeProducer(broker)

        with ExitStack() as es:
            es.enter_context(
                change_producer(broker.config.broker_config, fake_producer),
            )
            yield

    @staticmethod
    async def _fake_connect(  # type: ignore[override]
        broker: KafkaBroker,
        *args: Any,
        **kwargs: Any,
    ) -> Callable[..., AsyncMock]:
        """Replace the broker's admin client and consumer builder with fakes.

        Returns:
            ``_fake_connection``, a factory producing ``AsyncMock`` connections.
        """
        broker.config.broker_config._admin_client = AsyncMock()

        # The builder always hands back the same no-op FakeConsumer instance.
        builder = MagicMock(return_value=FakeConsumer())
        broker.config.broker_config.builder = builder

        return _fake_connection

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: KafkaBroker,
        publisher: "LogicPublisher",
    ) -> tuple["LogicSubscriber[Any]", bool]:
        """Find or create a subscriber that receives what ``publisher`` sends.

        Args:
            broker: Broker whose subscribers are searched / extended.
            publisher: Publisher whose ``topic``/``partition`` must be matched.

        Returns:
            ``(subscriber, is_real)``. ``is_real`` is ``True`` when an existing
            subscriber matched. It is ``False`` when a temporary
            (``persistent=False``) subscriber was registered on ``broker``.
        """
        for handler in broker.subscribers:
            handler = cast("LogicSubscriber[Any]", handler)
            if _is_handler_matches(handler, publisher.topic, publisher.partition):
                return handler, True

        topic_name = publisher.topic
        batch = isinstance(publisher, BatchPublisher)

        # ``None`` means "any partition" (see _is_handler_matches). Partition 0
        # is a real, explicit partition, so it must not be treated as falsy here.
        if publisher.partition is not None:
            tp = TopicPartition(
                topic=topic_name,
                partition=publisher.partition,
            )
            sub = broker.subscriber(
                partitions=[tp],
                batch=batch,
                persistent=False,
            )
        else:
            sub = broker.subscriber(
                topic_name,
                batch=batch,
                persistent=False,
            )

        return sub, False

98 

99 

class FakeConsumer:
    """No-op consumer returned by the fake builder installed in tests."""

    async def start(self) -> None:
        """Do nothing: there is no real connection to open."""

    async def stop(self) -> None:
        """Do nothing: there is no real connection to close."""

    def subscribe(self, *args: Any, **kwargs: Any) -> None:
        """Accept and ignore any subscription arguments."""

109 

110 

class FakeProducer(AioKafkaFastProducer):
    """In-memory stand-in for ``AioKafkaFastProducer`` used in tests.

    Nothing is sent over the network. Each outgoing message is converted into
    an ``aiokafka.ConsumerRecord`` via ``build_message`` and passed straight to
    the broker's matching subscribers.
    """

    def __init__(self, broker: KafkaBroker) -> None:
        self.broker = broker

        fallback = AioKafkaParser(msg_class=KafkaMessage, regex=None)

        # Compose the broker-level parser/decoder with the default Kafka ones.
        self._parser = ParserComposition(broker._parser, fallback.parse_message)
        self._decoder = ParserComposition(broker._decoder, fallback.decode_message)

    def __bool__(self) -> bool:
        # Always truthy, even though no real connection exists.
        return True

    @property
    def closed(self) -> bool:
        # The fake producer never reports itself as closed.
        return False

    def _matching_handlers(self, cmd: "KafkaPublishCommand") -> Iterator["LogicSubscriber[Any]"]:
        """Lazily yield the broker subscribers that ``cmd`` is routed to."""
        return _find_handler(
            cast("list[LogicSubscriber[Any]]", self.broker.subscribers),
            cmd.destination,
            cmd.partition,
        )

    def _incoming_from(
        self,
        cmd: "KafkaPublishCommand",
        *,
        reply_to: str = "",
    ) -> "ConsumerRecord":
        """Build the single record a subscriber receives for ``cmd``."""
        return build_message(
            message=cmd.body,
            topic=cmd.destination,
            key=cmd.key,
            partition=cmd.partition,
            timestamp_ms=cmd.timestamp_ms,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            reply_to=reply_to,
            serializer=self.broker.config.fd_config._serializer,
        )

    @override
    async def publish(self, cmd: "KafkaPublishCommand") -> None:
        """Deliver one message to every matching subscriber."""
        record = self._incoming_from(cmd, reply_to=cmd.reply_to)

        for subscriber in self._matching_handlers(cmd):
            # Batch subscribers expect a list even for a single record.
            payload = [record] if isinstance(subscriber, BatchSubscriber) else record
            await self._execute_handler(payload, cmd.destination, subscriber)

    @override
    async def request(self, cmd: "KafkaPublishCommand") -> "ConsumerRecord":
        """Deliver ``cmd`` to the first matching subscriber and return its reply.

        The handler call is bounded by ``anyio.fail_after(cmd.timeout)``.

        Raises:
            SubscriberNotFound: No subscriber matches the destination.
        """
        record = self._incoming_from(cmd)

        subscriber = next(self._matching_handlers(cmd), None)
        if subscriber is None:
            raise SubscriberNotFound

        payload = [record] if isinstance(subscriber, BatchSubscriber) else record
        with anyio.fail_after(cmd.timeout):
            return await self._execute_handler(
                payload,
                cmd.destination,
                subscriber,
            )

    @override
    async def publish_batch(
        self,
        cmd: "KafkaPublishCommand",
    ) -> None:
        """Deliver every body in ``cmd.batch_bodies`` to each matching subscriber."""
        for subscriber in self._matching_handlers(cmd):
            # A fresh lazy generator per subscriber: records are built on demand.
            records = (
                build_message(
                    message=body,
                    topic=cmd.destination,
                    partition=cmd.partition,
                    timestamp_ms=cmd.timestamp_ms,
                    key=cmd.key_for(position),
                    headers=cmd.headers,
                    correlation_id=cmd.correlation_id,
                    reply_to=cmd.reply_to,
                    serializer=self.broker.config.fd_config._serializer,
                )
                for position, body in enumerate(cmd.batch_bodies)
            )

            if not isinstance(subscriber, BatchSubscriber):
                # Non-batch subscribers get the records one at a time.
                for record in records:
                    await self._execute_handler(record, cmd.destination, subscriber)
                continue

            await self._execute_handler(list(records), cmd.destination, subscriber)

    async def _execute_handler(
        self,
        msg: Any,
        topic: str,
        handler: "LogicSubscriber[Any]",
    ) -> "ConsumerRecord":
        """Run ``handler`` on ``msg`` and wrap its response as a record on ``topic``."""
        response = await handler.process_message(msg)

        return build_message(
            topic=topic,
            message=response.body,
            headers=response.headers,
            correlation_id=response.correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )

236 

237 

def build_message(
    message: "SendableMessage",
    topic: str,
    partition: int | None = None,
    timestamp_ms: int | None = None,
    key: bytes | None = None,
    headers: dict[str, str] | None = None,
    correlation_id: str | None = None,
    *,
    reply_to: str = "",
    serializer: Optional["SerializerProto"],
) -> "ConsumerRecord":
    """Build a Kafka ConsumerRecord for a sendable message.

    Args:
        message: Payload, encoded with ``encode_message``.
        topic: Topic written into the record.
        partition: Partition number. ``None`` becomes ``0``.
        timestamp_ms: Record timestamp in milliseconds. ``None`` means "now"
            (UTC). An explicit ``0`` is kept rather than replaced.
        key: Record key. ``None`` (or empty) becomes ``b""``.
        headers: Extra headers. They override the generated ``content-type``
            and ``correlation_id`` entries on key collision.
        correlation_id: Correlation id. A new one is generated when falsy.
        reply_to: If non-empty, added as a ``reply_to`` header unless
            ``headers`` already provides one.
        serializer: Serializer forwarded to ``encode_message``.

    Returns:
        A ``ConsumerRecord`` with ``offset=0`` and ``timestamp_type=1``.
    """
    body, content_type = encode_message(message, serializer=serializer)

    record_key = key or b""

    all_headers = {
        "content-type": content_type or "",
        "correlation_id": correlation_id or gen_cor_id(),
        **(headers or {}),
    }

    if reply_to:
        # A "reply_to" already present in ``headers`` wins over the keyword.
        all_headers.setdefault("reply_to", reply_to)

    # Compare against None: 0 (the epoch) is a valid explicit timestamp.
    if timestamp_ms is None:
        timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)

    return ConsumerRecord(
        value=body,
        topic=topic,
        partition=partition or 0,
        key=record_key,
        serialized_key_size=len(record_key),
        serialized_value_size=len(body),
        # Placeholder checksum: the sum of the encoded payload's items.
        checksum=sum(body),
        offset=0,
        headers=[(name, value.encode()) for name, value in all_headers.items()],
        timestamp_type=1,
        timestamp=timestamp_ms,
    )

277 

278 

def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
    """Return a fake connection whose ``subscribe``/``assign`` are plain sync mocks.

    All arguments are ignored. The connection itself is an ``AsyncMock``, so
    its other attributes are awaitable mocks.
    """
    mock = AsyncMock()
    # Assign MagicMock *instances*, not the MagicMock class. With the class,
    # each call constructs a new mock from the call's arguments (positional
    # args are treated as ``spec``/``side_effect``), and the calls are never
    # recorded on the connection.
    mock.subscribe = MagicMock()
    mock.assign = MagicMock()
    return mock

284 

285 

def _find_handler(
    subscribers: Iterable["LogicSubscriber[Any]"],
    topic: str,
    partition: int | None,
) -> Generator["LogicSubscriber[Any]", None, None]:
    """Yield the subscribers that should receive a message on ``topic``/``partition``.

    Matching uses ``_is_handler_matches``. Among matching subscribers that share
    the same non-empty ``group_id``, only the first one (in iteration order) is
    yielded. This is presumably meant to mimic consumer-group delivery. Subscribers
    without a group are always yielded.
    """
    seen_groups = set()

    for candidate in subscribers:  # pragma: no branch
        if not _is_handler_matches(candidate, topic, partition):
            continue

        group = candidate.group_id
        if group:
            if group in seen_groups:
                continue
            seen_groups.add(group)

        yield candidate

300 

301 

302def _is_handler_matches( 

303 handler: "LogicSubscriber[Any]", 

304 topic: str, 

305 partition: int | None, 

306) -> bool: 

307 return bool( 

308 any( 

309 p.topic == topic and (partition is None or p.partition == partition) 

310 for p in handler.partitions 

311 ) 

312 or topic in handler.topics 

313 or (handler.pattern and re.match(handler.pattern, topic)), 

314 )