Coverage for faststream / rabbit / testing.py: 96%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1from collections.abc import Generator, Iterator, Mapping 

2from contextlib import ExitStack, contextmanager 

3from typing import TYPE_CHECKING, Any, Optional, Union, cast 

4from unittest import mock 

5from unittest.mock import AsyncMock 

6 

7import aiormq 

8import anyio 

9from aio_pika.message import IncomingMessage, encode_expiration 

10from pamqp import commands as spec 

11from pamqp.header import ContentHeader 

12from typing_extensions import override 

13 

14from faststream._internal.endpoint.utils import ParserComposition 

15from faststream._internal.testing.broker import TestBroker, change_producer 

16from faststream.exceptions import SubscriberNotFound 

17from faststream.message import gen_cor_id 

18from faststream.rabbit.broker.broker import RabbitBroker 

19from faststream.rabbit.parser import AioPikaParser 

20from faststream.rabbit.publisher.producer import AioPikaFastProducer 

21from faststream.rabbit.schemas import ( 

22 ExchangeType, 

23 RabbitExchange, 

24 RabbitQueue, 

25) 

26 

27if TYPE_CHECKING: 

28 from aio_pika.abc import DateType, HeadersType 

29 from fast_depends.library.serializer import SerializerProto 

30 

31 from faststream.rabbit.publisher import RabbitPublisher 

32 from faststream.rabbit.response import RabbitPublishCommand 

33 from faststream.rabbit.subscriber import RabbitSubscriber 

34 from faststream.rabbit.types import AioPikaSendableMessage 

35 

36__all__ = ("TestRabbitBroker",) 

37 

38 

class TestRabbitBroker(TestBroker[RabbitBroker]):
    """Test wrapper for ``RabbitBroker`` that replaces its internals with fakes."""

    @contextmanager
    def _patch_broker(self, broker: "RabbitBroker") -> Generator[None, None, None]:
        """Patch the broker's channel and declarer with ``AsyncMock`` objects.

        The patches are entered in this order: ``broker._channel``,
        ``broker.config.declarer``, then the base-class ``_patch_broker``.
        All of them are undone when the context exits.
        """
        with ExitStack() as stack:
            stack.enter_context(
                mock.patch.object(broker, "_channel", new_callable=AsyncMock),
            )
            stack.enter_context(
                mock.patch.object(
                    broker.config,
                    "declarer",
                    new_callable=AsyncMock,
                ),
            )
            stack.enter_context(super()._patch_broker(broker))
            yield

    @contextmanager
    def _patch_producer(self, broker: RabbitBroker) -> Iterator[None]:
        """Install a ``FakeProducer`` on the broker config while the context is active."""
        replacement = FakeProducer(broker)
        with change_producer(broker.config.broker_config, replacement):
            yield

    @staticmethod
    async def _fake_connect(broker: "RabbitBroker", *args: Any, **kwargs: Any) -> None:
        """Accept any connection arguments and do nothing."""
        return

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: "RabbitBroker",
        publisher: "RabbitPublisher",
    ) -> tuple["RabbitSubscriber", bool]:
        """Return a subscriber that receives what ``publisher`` sends.

        Returns:
            A ``(subscriber, is_real)`` pair. ``is_real`` is ``True`` when an
            already registered subscriber matches the publisher's routing key
            and exchange, and ``False`` when a new non-persistent subscriber
            had to be registered on the broker for it.
        """
        for registered in broker.subscribers:
            candidate = cast("RabbitSubscriber", registered)
            if _is_handler_matches(
                candidate,
                publisher.routing(),
                {},
                publisher.exchange,
            ):
                # First matching subscriber wins.
                return candidate, True

        # Nothing matched: register a subscriber on the publisher's destination.
        created = broker.subscriber(
            queue=publisher.routing(),
            exchange=publisher.exchange,
            persistent=False,
        )
        return created, False

101 

102 

class PatchedMessage(IncomingMessage):
    """``IncomingMessage`` subclass used by the test broker.

    The acknowledgement methods (``ack``, ``nack``, ``reject``) are overridden
    with empty bodies, so calling them has no effect.
    """

    routing_key: str

    async def ack(self, multiple: bool = False) -> None:
        """Do nothing; the arguments are accepted and ignored."""
        return None

    async def nack(self, multiple: bool = False, requeue: bool = True) -> None:
        """Do nothing; the arguments are accepted and ignored."""
        return None

    async def reject(self, requeue: bool = False) -> None:
        """Do nothing; the argument is accepted and ignored."""
        return None

119 

120 

def build_message(
    message: "AioPikaSendableMessage" = "",
    queue: Union["RabbitQueue", str] = "",
    exchange: Union["RabbitExchange", str, None] = None,
    *,
    routing_key: str = "",
    persist: bool = False,
    reply_to: str | None = None,
    headers: Optional["HeadersType"] = None,
    content_type: str | None = None,
    content_encoding: str | None = None,
    priority: int | None = None,
    correlation_id: str | None = None,
    expiration: Optional["DateType"] = None,
    message_id: str | None = None,
    timestamp: Optional["DateType"] = None,
    message_type: str | None = None,
    user_id: str | None = None,
    app_id: str | None = None,
    serializer: Optional["SerializerProto"] = None,
) -> PatchedMessage:
    """Build a patched RabbitMQ message for testing.

    The payload and properties are first encoded with
    ``AioPikaParser.encode_message``; the encoded message is then wrapped in
    an ``aiormq`` ``DeliveredMessage`` whose channel is an ``AsyncMock``.

    Args:
        message: Payload to encode.
        queue: Queue (or its name); its routing key is used when
            ``routing_key`` is empty.
        exchange: Exchange (or its name) recorded on the delivery frame.
            An empty exchange name is used when it has no ``name`` attribute.
        routing_key: Explicit routing key; takes precedence over ``queue``.
        correlation_id: Generated with ``gen_cor_id()`` when not given.
        message_id: Falls back to the correlation id when not given.
        serializer: Forwarded to ``AioPikaParser.encode_message``.

    The remaining keyword arguments are forwarded to
    ``AioPikaParser.encode_message`` unchanged.

    Returns:
        A ``PatchedMessage`` carrying the encoded body and properties.
    """
    que = RabbitQueue.validate(queue)
    exch = RabbitExchange.validate(exchange)

    routing = routing_key or que.routing()

    correlation_id = correlation_id or gen_cor_id()
    msg = AioPikaParser.encode_message(
        message=message,
        persist=persist,
        reply_to=reply_to,
        headers=headers,
        content_type=content_type,
        content_encoding=content_encoding,
        priority=priority,
        correlation_id=correlation_id,
        expiration=expiration,
        message_id=message_id or correlation_id,
        timestamp=timestamp,
        message_type=message_type,
        user_id=user_id,
        app_id=app_id,
        serializer=serializer,
    )

    properties = spec.Basic.Properties(
        content_type=msg.content_type,
        headers=msg.headers,
        reply_to=msg.reply_to,
        content_encoding=msg.content_encoding,
        priority=msg.priority,
        correlation_id=msg.correlation_id,
        expiration=encode_expiration(msg.expiration),
        message_id=msg.message_id,
        timestamp=msg.timestamp,
        # Read the type from the encoded message like every other property,
        # instead of from the raw ``message_type`` argument.
        # NOTE(review): presumably ``encode_message`` returns an already-built
        # aio_pika ``Message`` unchanged, in which case the argument would not
        # reflect the message's own ``type`` - confirm against the parser.
        message_type=msg.type,
        user_id=msg.user_id,
        app_id=msg.app_id,
    )

    return PatchedMessage(
        aiormq.abc.DeliveredMessage(
            delivery=spec.Basic.Deliver(
                exchange=getattr(exch, "name", ""),
                routing_key=routing,
            ),
            header=ContentHeader(properties=properties),
            body=msg.body,
            channel=AsyncMock(),
        ),
    )

193 

194 

class FakeProducer(AioPikaFastProducer):
    """In-memory replacement for ``AioPikaFastProducer`` used in tests.

    Published messages are handed directly to the matching subscribers
    registered on the wrapped broker.
    """

    def __init__(self, broker: RabbitBroker) -> None:
        self.broker = broker

        # The broker's own parser/decoder are composed with the default
        # aio-pika implementations.
        fallback = AioPikaParser()
        self._parser = ParserComposition(broker._parser, fallback.parse_message)
        self._decoder = ParserComposition(broker._decoder, fallback.decode_message)

    @override
    async def publish(
        self,
        cmd: "RabbitPublishCommand",
    ) -> None:
        """Deliver the command's message to every matching subscriber.

        Raises:
            SubscriberNotFound: If no registered subscriber matches.
        """
        incoming = build_message(
            message=cmd.body,
            exchange=cmd.exchange,
            routing_key=cmd.destination,
            correlation_id=cmd.correlation_id,
            headers=cmd.headers,
            reply_to=cmd.reply_to,
            serializer=self.broker.config.fd_config._serializer,
            **cmd.message_options,
        )

        delivered = False
        for registered in self.broker.subscribers:  # pragma: no branch
            subscriber = cast("RabbitSubscriber", registered)
            if not _is_handler_matches(
                subscriber,
                incoming.routing_key,
                incoming.headers,
                cmd.exchange,
            ):
                continue

            delivered = True
            await self._execute_handler(incoming, subscriber)

        if not delivered:
            raise SubscriberNotFound

    @override
    async def request(
        self,
        cmd: "RabbitPublishCommand",
    ) -> "PatchedMessage":
        """Send the message to the first matching subscriber and return its reply.

        The handler runs under ``anyio.fail_after(cmd.timeout)``.

        Raises:
            SubscriberNotFound: If no registered subscriber matches.
        """
        incoming = build_message(
            message=cmd.body,
            exchange=cmd.exchange,
            routing_key=cmd.destination,
            correlation_id=cmd.correlation_id,
            headers=cmd.headers,
            serializer=self.broker.config.fd_config._serializer,
            **cmd.message_options,
        )

        # Only the first match is used; later subscribers are not inspected.
        target = next(
            (
                cast("RabbitSubscriber", registered)
                for registered in self.broker.subscribers
                if _is_handler_matches(
                    cast("RabbitSubscriber", registered),
                    incoming.routing_key,
                    incoming.headers,
                    cmd.exchange,
                )
            ),
            None,
        )
        if target is None:
            raise SubscriberNotFound

        with anyio.fail_after(cmd.timeout):
            return await self._execute_handler(incoming, target)

    async def _execute_handler(
        self,
        msg: PatchedMessage,
        handler: "RabbitSubscriber",
    ) -> "PatchedMessage":
        """Run ``handler`` on ``msg`` and wrap its result in a new patched message.

        The reply keeps the routing key of the incoming message and takes its
        body, headers and correlation id from the handler result.
        """
        response = await handler.process_message(msg)
        return build_message(
            routing_key=msg.routing_key,
            message=response.body,
            headers=response.headers,
            correlation_id=response.correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )

285 

286 

def _is_handler_matches(
    handler: "RabbitSubscriber",
    routing_key: str,
    headers: Optional["Mapping[Any, Any]"] = None,
    exchange: Optional["RabbitExchange"] = None,
) -> bool:
    """Decide whether ``handler`` would receive a message.

    The handler must be bound to the same exchange as the message. The rest
    depends on the exchange type:

    * no exchange / direct: the handler routing key must equal ``routing_key``;
    * fanout: always matches;
    * topic: the handler routing key is applied as a pattern to ``routing_key``;
    * headers: the queue's bind arguments are compared with ``headers``,
      honouring ``x-match`` (``"all"`` by default).

    Raises:
        AssertionError: If the exchange type is none of the above.
    """
    incoming_headers = headers or {}
    target_exchange = RabbitExchange.validate(exchange)

    bound_exchange = handler.exchange
    if bound_exchange != target_exchange:
        return False

    if bound_exchange is None:
        return handler.routing() == routing_key

    kind = bound_exchange.type

    if kind == ExchangeType.DIRECT:
        return handler.routing() == routing_key

    if kind == ExchangeType.FANOUT:
        return True

    if kind == ExchangeType.TOPIC:
        return apply_pattern(handler.routing(), routing_key)

    if kind != ExchangeType.HEADERS:
        raise AssertionError

    # Work on a copy so popping "x-match" does not touch the queue definition.
    bind_args = (handler.queue.bind_arguments or {}).copy()
    if not bind_args:
        return True

    match_rule = bind_args.pop("x-match", "all")

    # One flag per bound header: True when the message value differs.
    mismatches = [
        incoming_headers.get(name) != expected
        for name, expected in bind_args.items()
    ]

    # Not a single bound header matched (also the case when only "x-match"
    # was bound): no delivery, whatever the match rule is.
    if all(mismatches):
        return False

    return not any(mismatches) or (match_rule == "any")

330 

331 

def apply_pattern(pattern: str, current: str) -> bool:
    """Check whether a routing key matches a topic-exchange binding pattern.

    Both strings are split on ``"."`` into words. In the pattern, ``*`` stands
    for exactly one word and ``#`` stands for zero or more words; any other
    word must equal the corresponding word of the routing key.

    Args:
        pattern: Binding pattern, e.g. ``"logs.*.error"`` or ``"logs.#"``.
        current: Routing key of the message being routed.

    Returns:
        ``True`` if the whole routing key is matched by the whole pattern.
    """
    pattern_words = pattern.split(".")
    key_words = current.split(".")

    # reachable[i] is True when the pattern words processed so far can match
    # exactly the first i words of the routing key. Tracking every reachable
    # position (instead of a single cursor) lets "#" absorb zero words and
    # lets the match succeed on a later occurrence of the word following "#".
    reachable = [True] + [False] * len(key_words)

    for word in pattern_words:
        if word == "#":
            # "#" can absorb any number of words, so every position at or
            # after the first reachable one becomes reachable. There is always
            # at least one reachable position here: the initial state has one
            # and the branch below returns as soon as none is left.
            first = reachable.index(True)
            reachable = [False] * first + [True] * (len(reachable) - first)
        else:
            # A literal word or "*" consumes exactly one routing-key word.
            reachable = [False] + [
                ok and word in {"*", key_word}
                for ok, key_word in zip(reachable, key_words)
            ]
            if not any(reachable):
                return False

    # The pattern matches only if it can consume the entire routing key.
    return reachable[-1]