Coverage for faststream / confluent / testing.py: 86%

111 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1from collections.abc import Callable, Generator, Iterable, Iterator 

2from contextlib import ExitStack, contextmanager 

3from datetime import datetime, timezone 

4from typing import TYPE_CHECKING, Any, Optional, cast 

5from unittest.mock import AsyncMock, MagicMock 

6 

7import anyio 

8from typing_extensions import override 

9 

10from faststream._internal.endpoint.utils import ParserComposition 

11from faststream._internal.testing.broker import TestBroker, change_producer 

12from faststream.confluent.broker import KafkaBroker 

13from faststream.confluent.parser import AsyncConfluentParser 

14from faststream.confluent.publisher.producer import AsyncConfluentFastProducer 

15from faststream.confluent.publisher.usecase import BatchPublisher 

16from faststream.confluent.schemas import TopicPartition 

17from faststream.confluent.subscriber.usecase import BatchSubscriber 

18from faststream.exceptions import SubscriberNotFound 

19from faststream.message import encode_message, gen_cor_id 

20 

21if TYPE_CHECKING: 

22 from fast_depends.library.serializer import SerializerProto 

23 

24 from faststream._internal.basic_types import SendableMessage 

25 from faststream.confluent.publisher.usecase import LogicPublisher 

26 from faststream.confluent.response import KafkaPublishCommand 

27 from faststream.confluent.subscriber.usecase import LogicSubscriber 

28 

29 

30__all__ = ("TestKafkaBroker",) 

31 

32 

class TestKafkaBroker(TestBroker[KafkaBroker]):
    """Test harness for the Confluent-based ``KafkaBroker``.

    While active, the broker's producer is swapped for a ``FakeProducer``
    (which hands published messages straight to in-process subscribers),
    and connecting installs a ``MagicMock`` admin client plus the
    ``_fake_connection`` mock factory instead of talking to a cluster.
    """

    @contextmanager
    def _patch_producer(self, broker: KafkaBroker) -> Iterator[None]:
        """Install a ``FakeProducer`` on *broker* for the duration of the context."""
        fake_producer = FakeProducer(broker)

        with ExitStack() as es:
            es.enter_context(
                change_producer(broker.config.broker_config, fake_producer),
            )
            yield

    @staticmethod
    async def _fake_connect(  # type: ignore[override]
        broker: KafkaBroker,
        *args: Any,
        **kwargs: Any,
    ) -> Callable[..., AsyncMock]:
        """Stand in for the real connect step.

        Replaces the broker's admin client with a ``MagicMock`` and returns
        ``_fake_connection``, a factory that produces ``AsyncMock`` clients.
        Extra positional/keyword arguments are accepted and ignored.
        """
        broker.config.broker_config.admin.admin_client = MagicMock()
        return _fake_connection

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: KafkaBroker,
        publisher: "LogicPublisher",
    ) -> tuple["LogicSubscriber[Any]", bool]:
        """Find or create a subscriber that receives what *publisher* sends.

        Args:
            broker: The broker under test.
            publisher: The publisher that needs a receiving subscriber.

        Returns:
            A ``(subscriber, is_real)`` pair. ``is_real`` is ``True`` when an
            existing broker subscriber already matches the publisher's
            topic/partition. It is ``False`` when a new non-persistent
            subscriber (``auto_offset_reset="earliest"``) was registered.
        """
        sub: LogicSubscriber[Any] | None = None
        for handler in broker.subscribers:
            handler = cast("LogicSubscriber[Any]", handler)
            if _is_handler_matches(
                handler,
                topic=publisher.topic,
                partition=publisher.partition,
            ):
                sub = handler
                break

        if sub is None:
            is_real = False

            topic_name = publisher.topic
            is_batch = isinstance(publisher, BatchPublisher)

            # Partition 0 is a valid explicit partition, so compare against
            # None instead of relying on truthiness. This matches how
            # partitions are compared when routing messages in this module.
            if publisher.partition is not None:
                tp = TopicPartition(
                    topic=topic_name,
                    partition=publisher.partition,
                )
                sub = broker.subscriber(
                    partitions=[tp],
                    batch=is_batch,
                    auto_offset_reset="earliest",
                    persistent=False,
                )
            else:
                sub = broker.subscriber(
                    topic_name,
                    batch=is_batch,
                    auto_offset_reset="earliest",
                    persistent=False,
                )
        else:
            is_real = True

        return sub, is_real

98 

99 

class FakeProducer(AsyncConfluentFastProducer):
    """In-memory producer used while testing a ``KafkaBroker``.

    No Kafka cluster is involved. Every publish/request call builds
    ``MockConfluentMessage`` objects and passes them directly to the
    ``process_message`` method of the broker's matching subscribers.
    """

    def __init__(self, broker: KafkaBroker) -> None:
        # NOTE(review): the base-class initializer is not called here; only
        # the attributes below are set up.
        self.broker = broker

        # Compose the broker's own parser/decoder with the stock Confluent
        # parser's methods (ordering/fallback semantics are defined by
        # ParserComposition — confirm there).
        fallback = AsyncConfluentParser()
        self._parser = ParserComposition(broker._parser, fallback.parse_message)
        self._decoder = ParserComposition(broker._decoder, fallback.decode_message)

    def __bool__(self) -> bool:
        # The fake producer is always truthy.
        return True

    async def ping(self, timeout: float) -> bool:
        """Report the fake producer as reachable; *timeout* is ignored."""
        return True

    def _matching_subscribers(
        self,
        cmd: "KafkaPublishCommand",
    ) -> Iterable["LogicSubscriber[Any]"]:
        """Lazily yield the broker subscribers that should receive *cmd*."""
        subscribers = cast("Iterable[LogicSubscriber[Any]]", self.broker.subscribers)
        return _find_handler(subscribers, cmd.destination, cmd.partition)

    @override
    async def publish(self, cmd: "KafkaPublishCommand") -> None:
        """Deliver a single message to every matching subscriber.

        Batch subscribers receive the message wrapped in a one-element list.
        """
        incoming = build_message(
            message=cmd.body,
            topic=cmd.destination,
            key=cmd.key,
            partition=cmd.partition,
            timestamp_ms=cmd.timestamp_ms,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            reply_to=cmd.reply_to,
            serializer=self.broker.config.fd_config._serializer,
        )

        for subscriber in self._matching_subscribers(cmd):
            if isinstance(subscriber, BatchSubscriber):
                payload: Any = [incoming]
            else:
                payload = incoming
            await self._execute_handler(payload, cmd.destination, subscriber)

    @override
    async def publish_batch(self, cmd: "KafkaPublishCommand") -> None:
        """Deliver every body of a batch command to each matching subscriber.

        A fresh, lazily-built set of messages is produced per subscriber.
        Batch subscribers get the whole list in one call. Any other
        subscriber gets one call per message, and each message is built just
        before it is handled.
        """
        for subscriber in self._matching_subscribers(cmd):
            built = (
                build_message(
                    message=body,
                    topic=cmd.destination,
                    partition=cmd.partition,
                    timestamp_ms=cmd.timestamp_ms,
                    key=cmd.key_for(position),
                    headers=cmd.headers,
                    correlation_id=cmd.correlation_id,
                    reply_to=cmd.reply_to,
                    serializer=self.broker.config.fd_config._serializer,
                )
                for position, body in enumerate(cmd.batch_bodies)
            )

            if not isinstance(subscriber, BatchSubscriber):
                for single in built:
                    await self._execute_handler(single, cmd.destination, subscriber)
                continue

            await self._execute_handler(list(built), cmd.destination, subscriber)

    @override
    async def request(self, cmd: "KafkaPublishCommand") -> "MockConfluentMessage":
        """Send *cmd* to the first matching subscriber and return its reply.

        Handling runs inside ``anyio.fail_after(cmd.timeout)``.

        Raises:
            SubscriberNotFound: If no subscriber matches the destination.
        """
        incoming = build_message(
            message=cmd.body,
            topic=cmd.destination,
            key=cmd.key,
            partition=cmd.partition,
            timestamp_ms=cmd.timestamp_ms,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )

        first = next(iter(self._matching_subscribers(cmd)), None)
        if first is None:
            raise SubscriberNotFound

        payload = [incoming] if isinstance(first, BatchSubscriber) else incoming
        with anyio.fail_after(cmd.timeout):
            return await self._execute_handler(payload, cmd.destination, first)

    async def _execute_handler(
        self,
        msg: Any,
        topic: str,
        handler: "LogicSubscriber[Any]",
    ) -> "MockConfluentMessage":
        """Run *handler* on *msg* and wrap its result as a message on *topic*.

        A new correlation id is generated when the result carries none.
        """
        result = await handler.process_message(msg)
        reply_cor_id = result.correlation_id or gen_cor_id()
        return build_message(
            topic=topic,
            message=result.body,
            headers=result.headers,
            correlation_id=reply_cor_id,
            serializer=self.broker.config.fd_config._serializer,
        )

217 

218 

219class MockConfluentMessage: 

220 def __init__( 

221 self, 

222 raw_msg: bytes, 

223 topic: str, 

224 key: bytes | str, 

225 headers: list[tuple[str, bytes]], 

226 offset: int, 

227 partition: int, 

228 timestamp_type: int, 

229 timestamp_ms: int, 

230 error: str | None = None, 

231 ) -> None: 

232 self._raw_msg = raw_msg 

233 self._topic = topic 

234 

235 if isinstance(key, str): 235 ↛ 236line 235 didn't jump to line 236 because the condition on line 235 was never true

236 self._key = key.encode() 

237 else: 

238 self._key = key 

239 

240 self._headers = headers 

241 self._error = error 

242 self._offset = offset 

243 self._partition = partition 

244 self._timestamp = (timestamp_type, timestamp_ms) 

245 

246 def len(self) -> int: 

247 return len(self._raw_msg) 

248 

249 def error(self) -> str | None: 

250 return self._error 

251 

252 def headers(self) -> list[tuple[str, bytes]]: 

253 return self._headers 

254 

255 def key(self) -> bytes: 

256 return self._key 

257 

258 def offset(self) -> int: 

259 return self._offset 

260 

261 def partition(self) -> int: 

262 return self._partition 

263 

264 def timestamp(self) -> tuple[int, int]: 

265 return self._timestamp 

266 

267 def topic(self) -> str: 

268 return self._topic 

269 

270 def value(self) -> bytes: 

271 return self._raw_msg 

272 

273 

def build_message(
    message: "SendableMessage",
    topic: str,
    *,
    correlation_id: str | None = None,
    partition: int | None = None,
    timestamp_ms: int | None = None,
    key: bytes | str | None = None,
    headers: dict[str, str] | None = None,
    reply_to: str = "",
    serializer: Optional["SerializerProto"] = None,
) -> MockConfluentMessage:
    """Build a ``MockConfluentMessage`` standing in for a ``confluent_kafka.Message``.

    *message* is encoded with ``encode_message`` (using *serializer*). The
    header list always starts with ``content-type``, ``correlation_id`` (a
    fresh id when none is given) and ``reply_to``. Any caller-supplied
    *headers* are then merged in and win on key collisions. All header values
    are encoded to bytes.

    Falsy values fall back to defaults: an empty key becomes ``b""``, the
    partition becomes ``0``, and the timestamp becomes the current UTC time in
    milliseconds. The offset is always ``0``.
    """
    body, content_type = encode_message(message, serializer)

    merged_headers = {
        "content-type": content_type or "",
        "correlation_id": correlation_id or gen_cor_id(),
        "reply_to": reply_to,
    }
    merged_headers.update(headers or {})

    if not timestamp_ms:
        timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)

    # https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#confluent_kafka.Message.timestamp
    # timestamp_type=1 presumably corresponds to TIMESTAMP_CREATE_TIME in the
    # docs above — confirm.
    return MockConfluentMessage(
        raw_msg=body,
        topic=topic,
        key=key or b"",
        headers=[(name, value.encode()) for name, value in merged_headers.items()],
        offset=0,
        partition=partition or 0,
        timestamp_type=1,
        timestamp_ms=timestamp_ms,
    )

307 

308 

def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
    """Create a mock client in place of a real connection.

    Any arguments are accepted and ignored. The returned ``AsyncMock`` is
    configured so that ``getone`` resolves to a ``MagicMock`` and ``getmany``
    resolves to a one-element list holding a ``MagicMock``.
    """
    client = AsyncMock()
    client.configure_mock(
        **{
            "getone.return_value": MagicMock(),
            "getmany.return_value": [MagicMock()],
        },
    )
    return client

314 

315 

def _find_handler(
    subscribers: Iterable["LogicSubscriber[Any]"],
    topic: str,
    partition: int | None,
) -> Generator["LogicSubscriber[Any]", None, None]:
    """Lazily yield the subscribers that should receive a message.

    A subscriber qualifies when ``_is_handler_matches`` accepts it for
    *topic*/*partition*. Matching subscribers with a falsy ``group_id`` are
    always yielded. For each truthy ``group_id``, only the first matching
    subscriber is yielded and later ones sharing that id are skipped.
    """
    seen_groups: set[Any] = set()
    for candidate in subscribers:  # pragma: no branch
        if not _is_handler_matches(candidate, topic, partition):
            continue

        group = candidate.group_id
        if group:
            if group in seen_groups:
                continue
            seen_groups.add(group)

        yield candidate

330 

331 

332def _is_handler_matches( 

333 handler: "LogicSubscriber[Any]", 

334 topic: str, 

335 partition: int | None, 

336) -> bool: 

337 return bool( 

338 any( 

339 p.topic == topic and (partition is None or p.partition == partition) 

340 for p in handler.partitions 

341 ) 

342 or topic in handler.topics, 

343 )