Coverage for faststream / rabbit / subscriber / usecase.py: 99%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import asyncio 

2import contextlib 

3from collections.abc import AsyncIterator, Sequence 

4from typing import TYPE_CHECKING, Any, Optional, cast 

5 

6import anyio 

7from typing_extensions import override 

8 

9from faststream._internal.endpoint.subscriber import SubscriberUsecase 

10from faststream._internal.endpoint.utils import process_msg 

11from faststream.rabbit.parser import AioPikaParser 

12from faststream.rabbit.publisher.fake import RabbitFakePublisher 

13from faststream.rabbit.schemas import RabbitExchange 

14from faststream.rabbit.schemas.constants import REPLY_TO_QUEUE_EXCHANGE_DELIMITER 

15 

16if TYPE_CHECKING: 

17 from aio_pika import IncomingMessage, RobustQueue 

18 

19 from faststream._internal.endpoint.publisher import PublisherProto 

20 from faststream._internal.endpoint.subscriber.call_item import CallsCollection 

21 from faststream._internal.endpoint.subscriber.specification import ( 

22 SubscriberSpecification, 

23 ) 

24 from faststream.message import StreamMessage 

25 from faststream.rabbit.configs import RabbitBrokerConfig 

26 from faststream.rabbit.message import RabbitMessage 

27 from faststream.rabbit.schemas import RabbitQueue 

28 

29 from .config import RabbitSubscriberConfig 

30 

31 

class RabbitSubscriber(SubscriberUsecase["IncomingMessage"]):
    """A class to handle logic for RabbitMQ message consumption.

    On ``start`` the subscriber declares its (prefixed) queue, optionally binds
    it to ``self.exchange``, and registers ``self.consume`` as an aio-pika
    consumer when handlers are attached. Without handlers it can be used in
    pull mode via ``get_one`` or ``async for``.
    """

    _outer_config: "RabbitBrokerConfig"

    def __init__(
        self,
        config: "RabbitSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[IncomingMessage]",
    ) -> None:
        # Install the Rabbit-specific parser/decoder on the config before the
        # base class reads it. The queue's ``path_regex`` is handed to the
        # parser (presumably for path-parameter extraction — confirm).
        parser = AioPikaParser(pattern=config.queue.path_regex)
        config.decoder = parser.decode_message
        config.parser = parser.parse_message
        super().__init__(
            config,
            specification=specification,
            calls=calls,
        )

        self.queue = config.queue
        self.exchange = config.exchange

        self.consume_args = config.consume_args or {}

        # ``ack_first`` is forwarded as aio-pika's ``no_ack`` flag when the
        # consumer is registered in ``start``.
        self.__no_ack = config.ack_first

        # Runtime state: populated by ``start`` and cleared by ``stop``.
        self._consumer_tag: str | None = None
        self._queue_obj: RobustQueue | None = None
        self.channel = config.channel

    @property
    def app_id(self) -> str | None:
        """Application id taken from the broker-level config."""
        return self._outer_config.app_id

    def routing(self) -> str:
        """Return the queue's routing key with the broker prefix prepended."""
        return f"{self._outer_config.prefix}{self.queue.routing()}"

    @override
    async def start(self) -> None:
        """Starts the consumer for the RabbitMQ queue.

        Declares the prefixed queue on ``self.channel``, binds it to
        ``self.exchange`` when applicable, and, if handlers are registered,
        starts consuming with ``self.consume`` as the callback.
        """
        await super().start()

        queue_to_bind = self.queue.add_prefix(self._outer_config.prefix)

        declarer = self._outer_config.declarer

        self._queue_obj = queue = await declarer.declare_queue(
            queue_to_bind,
            channel=self.channel,
        )

        if (
            self.exchange is not None
            # Only bind queues marked for declaration. NOTE(review): presumably
            # ``declare=False`` means the queue is only looked up on the broker
            # and its bindings are left untouched — confirm.
            and queue_to_bind.declare
            # Skip the default (nameless) exchange.
            and self.exchange.name
        ):
            exchange = await declarer.declare_exchange(
                self.exchange,
                channel=self.channel,
            )

            await queue.bind(
                exchange,
                routing_key=queue_to_bind.routing(),
                arguments=queue_to_bind.bind_arguments,
                timeout=queue_to_bind.timeout,
                robust=self.queue.robust,
            )

        if self.calls:
            self._consumer_tag = await queue.consume(
                # NOTE: aio-pika expects AbstractIncomingMessage, not IncomingMessage
                self.consume,  # type: ignore[arg-type]
                no_ack=self.__no_ack,
                arguments=self.consume_args,
            )

        self._post_start()

    async def stop(self) -> None:
        """Stop the subscriber and cancel its broker-side consumer.

        The consumer is cancelled only while its channel is still open. The
        cached consumer tag and queue object are then cleared.
        """
        await super().stop()

        if self._queue_obj is not None:
            if self._consumer_tag is not None:  # pragma: no branch
                if not self._queue_obj.channel.is_closed:
                    await self._queue_obj.cancel(self._consumer_tag)
                self._consumer_tag = None

            self._queue_obj = None

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
        no_ack: bool = True,
    ) -> "RabbitMessage | None":
        """Poll the queue for a single message.

        Args:
            timeout: Overall time budget in seconds. The queue is polled every
                ``timeout / 10`` seconds until a message arrives or the budget
                expires.
            no_ack: Forwarded to aio-pika's ``Queue.get``.

        Returns:
            The parsed message. If nothing arrived in time, ``raw_message``
            stays ``None`` and is passed to ``process_msg`` (presumably
            returning ``None`` — confirm).

        Raises:
            AssertionError: If the subscriber is not started or has handlers.
        """
        assert self._queue_obj, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        sleep_interval = timeout / 10

        raw_message: IncomingMessage | None = None
        # NOTE(review): suppressing CancelledError also swallows cancellation of
        # the calling task. Presumably it exists because the timeout-triggered
        # cancellation can surface from aio-pika as a bare CancelledError —
        # confirm before changing.
        with (
            contextlib.suppress(asyncio.exceptions.CancelledError),
            anyio.move_on_after(timeout),
        ):
            while (  # noqa: ASYNC110
                raw_message := await self._queue_obj.get(
                    fail=False,
                    no_ack=no_ack,
                    timeout=timeout,
                )
            ) is None:
                await anyio.sleep(sleep_interval)

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        msg: RabbitMessage | None = await process_msg(  # type: ignore[assignment]
            msg=raw_message,
            middlewares=(
                m(raw_message, context=context) for m in self._broker_middlewares
            ),
            parser=async_parser,
            decoder=async_decoder,
        )
        return msg

    @override
    async def __aiter__(self) -> AsyncIterator["RabbitMessage"]:  # type: ignore[override]
        """Iterate over incoming messages, yielding each one parsed.

        Each raw message is run through the broker middlewares and this
        subscriber's parser/decoder via ``process_msg``.

        Raises:
            AssertionError: If the subscriber is not started or has handlers.
        """
        assert self._queue_obj, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use iterator method if subscriber has registered handlers."
        )

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        async with self._queue_obj.iterator() as queue_iter:
            async for raw_message in queue_iter:
                raw_message = cast("IncomingMessage", raw_message)

                msg: RabbitMessage = await process_msg(  # type: ignore[assignment]
                    msg=raw_message,
                    middlewares=(
                        m(raw_message, context=context) for m in self._broker_middlewares
                    ),
                    parser=async_parser,
                    decoder=async_decoder,
                )
                yield msg

    def _make_response_publisher(
        self,
        message: "StreamMessage[Any]",
    ) -> Sequence["PublisherProto"]:
        """Build the publisher used to send a reply to ``message.reply_to``.

        ``reply_to`` is either a plain routing key, which is published through
        a default ``RabbitExchange()``, or
        ``<routing key><REPLY_TO_QUEUE_EXCHANGE_DELIMITER><exchange name>``.
        """
        reply_to = message.reply_to

        if REPLY_TO_QUEUE_EXCHANGE_DELIMITER in reply_to:
            # Split only on the first delimiter so the unpacking always gets
            # exactly two parts. With maxsplit=2, a second delimiter produced
            # three parts and the unpacking raised ValueError.
            routing_key, exchange_name = reply_to.split(
                REPLY_TO_QUEUE_EXCHANGE_DELIMITER, 1
            )
            exchange = RabbitExchange.validate(exchange_name)
        else:
            routing_key = reply_to
            exchange = RabbitExchange()

        publisher = RabbitFakePublisher(
            self._outer_config.producer,
            app_id=self.app_id,
            routing_key=routing_key,
            exchange=exchange,
        )
        return (publisher,)

    @staticmethod
    def build_log_context(
        message: Optional["StreamMessage[Any]"],
        queue: "RabbitQueue",
        exchange: Optional["RabbitExchange"] = None,
    ) -> dict[str, str]:
        """Build the logging context for a message.

        A missing exchange or message yields ``""`` for the corresponding key.
        """
        return {
            "queue": queue.name,
            "exchange": getattr(exchange, "name", ""),
            "message_id": getattr(message, "message_id", ""),
        }

    def get_log_context(
        self,
        message: Optional["StreamMessage[Any]"],
    ) -> dict[str, str]:
        """Return the logging context for ``message`` using this subscriber's queue/exchange."""
        return self.build_log_context(
            message=message,
            queue=self.queue,
            exchange=self.exchange,
        )