Coverage for faststream / rabbit / subscriber / usecase.py: 99%
71 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import asyncio
2import contextlib
3from collections.abc import AsyncIterator, Sequence
4from typing import TYPE_CHECKING, Any, Optional, cast
6import anyio
7from typing_extensions import override
9from faststream._internal.endpoint.subscriber import SubscriberUsecase
10from faststream._internal.endpoint.utils import process_msg
11from faststream.rabbit.parser import AioPikaParser
12from faststream.rabbit.publisher.fake import RabbitFakePublisher
13from faststream.rabbit.schemas import RabbitExchange
14from faststream.rabbit.schemas.constants import REPLY_TO_QUEUE_EXCHANGE_DELIMITER
16if TYPE_CHECKING:
17 from aio_pika import IncomingMessage, RobustQueue
19 from faststream._internal.endpoint.publisher import PublisherProto
20 from faststream._internal.endpoint.subscriber.call_item import CallsCollection
21 from faststream._internal.endpoint.subscriber.specification import (
22 SubscriberSpecification,
23 )
24 from faststream.message import StreamMessage
25 from faststream.rabbit.configs import RabbitBrokerConfig
26 from faststream.rabbit.message import RabbitMessage
27 from faststream.rabbit.schemas import RabbitQueue
29 from .config import RabbitSubscriberConfig
class RabbitSubscriber(SubscriberUsecase["IncomingMessage"]):
    """A class to handle logic for RabbitMQ message consumption.

    On ``start()`` the subscriber declares its (prefixed) queue and, when
    configured, binds it to a named exchange. If handlers are registered it
    consumes through ``self.consume``; otherwise messages are pulled manually
    with ``get_one()`` or by iterating the subscriber with ``async for``.
    """

    _outer_config: "RabbitBrokerConfig"

    def __init__(
        self,
        config: "RabbitSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[IncomingMessage]",
    ) -> None:
        # The AioPika parser is built from the queue's ``path_regex`` and
        # overwrites the decoder/parser stored on the config before the base
        # class is initialised with it.
        parser = AioPikaParser(pattern=config.queue.path_regex)
        config.decoder = parser.decode_message
        config.parser = parser.parse_message
        super().__init__(
            config,
            specification=specification,
            calls=calls,
        )

        self.queue = config.queue
        self.exchange = config.exchange

        self.consume_args = config.consume_args or {}

        # ``ack_first`` is forwarded to aio-pika as the consumer's ``no_ack`` flag.
        self.__no_ack = config.ack_first

        # Runtime state: set by ``start()`` and reset by ``stop()``.
        self._consumer_tag: str | None = None
        self._queue_obj: RobustQueue | None = None
        self.channel = config.channel

    @property
    def app_id(self) -> str | None:
        """Application id taken from the broker configuration."""
        return self._outer_config.app_id

    def routing(self) -> str:
        """Return the queue's routing key with the broker prefix prepended."""
        return f"{self._outer_config.prefix}{self.queue.routing()}"

    @override
    async def start(self) -> None:
        """Starts the consumer for the RabbitMQ queue.

        Declares the prefixed queue, binds it to ``self.exchange`` when the
        queue is declared by us and the exchange is not the default one, and
        starts consuming if any handlers are registered.
        """
        await super().start()

        queue_to_bind = self.queue.add_prefix(self._outer_config.prefix)

        declarer = self._outer_config.declarer

        self._queue_obj = queue = await declarer.declare_queue(
            queue_to_bind,
            channel=self.channel,
        )

        if (
            self.exchange is not None
            # with ``declare`` disabled the queue is only fetched from RMQ
            and queue_to_bind.declare
            # skip the default (nameless) exchange
            and self.exchange.name
        ):
            exchange = await declarer.declare_exchange(
                self.exchange,
                channel=self.channel,
            )

            await queue.bind(
                exchange,
                routing_key=queue_to_bind.routing(),
                arguments=queue_to_bind.bind_arguments,
                timeout=queue_to_bind.timeout,
                robust=self.queue.robust,
            )

        if self.calls:
            self._consumer_tag = await self._queue_obj.consume(
                # NOTE: aio-pika expects AbstractIncomingMessage, not IncomingMessage
                self.consume,  # type: ignore[arg-type]
                no_ack=self.__no_ack,
                arguments=self.consume_args,
            )

        self._post_start()

    async def stop(self) -> None:
        """Stop the subscriber and cancel its RabbitMQ consumer, if one is active."""
        await super().stop()

        if self._queue_obj is not None:
            if self._consumer_tag is not None:  # pragma: no branch
                # Only cancel while the channel is still open; the tag is
                # cleared either way.
                if not self._queue_obj.channel.is_closed:
                    await self._queue_obj.cancel(self._consumer_tag)
                self._consumer_tag = None

            self._queue_obj = None

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
        no_ack: bool = True,
    ) -> "RabbitMessage | None":
        """Poll the queue for a single message.

        Args:
            timeout: Overall time budget in seconds. The queue is polled,
                sleeping ``timeout / 10`` seconds between empty polls, until a
                message arrives or the budget is exhausted.
            no_ack: Forwarded to ``RobustQueue.get``.

        Returns:
            The message after middlewares, parser and decoder have been applied,
            or ``None`` when nothing arrived in time (assumes ``process_msg``
            returns ``None`` for a ``None`` input — confirm).

        Raises:
            AssertionError: If the subscriber is not started or has handlers.
        """
        assert self._queue_obj, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        sleep_interval = timeout / 10

        raw_message: IncomingMessage | None = None
        # NOTE(review): suppressing CancelledError outside ``move_on_after``
        # also swallows cancellation of the calling task — confirm intended.
        with (
            contextlib.suppress(asyncio.exceptions.CancelledError),
            anyio.move_on_after(timeout),
        ):
            while (  # noqa: ASYNC110
                raw_message := await self._queue_obj.get(
                    fail=False,
                    no_ack=no_ack,
                    timeout=timeout,
                )
            ) is None:
                await anyio.sleep(sleep_interval)

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        msg: RabbitMessage | None = await process_msg(  # type: ignore[assignment]
            msg=raw_message,
            middlewares=(
                m(raw_message, context=context) for m in self._broker_middlewares
            ),
            parser=async_parser,
            decoder=async_decoder,
        )
        return msg

    @override
    async def __aiter__(self) -> AsyncIterator["RabbitMessage"]:  # type: ignore[override]
        """Yield processed messages from the queue as they arrive.

        Raises:
            AssertionError: If the subscriber is not started or has handlers.
        """
        assert self._queue_obj, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use iterator method if subscriber has registered handlers."
        )

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        async with self._queue_obj.iterator() as queue_iter:
            async for raw_message in queue_iter:
                raw_message = cast("IncomingMessage", raw_message)

                msg: RabbitMessage = await process_msg(  # type: ignore[assignment]
                    msg=raw_message,
                    middlewares=(
                        m(raw_message, context=context) for m in self._broker_middlewares
                    ),
                    parser=async_parser,
                    decoder=async_decoder,
                )
                yield msg

    def _make_response_publisher(
        self,
        message: "StreamMessage[Any]",
    ) -> Sequence["PublisherProto"]:
        """Build the publisher used to send the reply for ``message``.

        ``message.reply_to`` is either a bare routing key, in which case the
        reply is published via ``RabbitExchange()``, or
        ``<queue><REPLY_TO_QUEUE_EXCHANGE_DELIMITER><exchange>``, in which case
        it is published to ``<exchange>`` with ``<queue>`` as the routing key.
        """
        reply_to = message.reply_to

        if REPLY_TO_QUEUE_EXCHANGE_DELIMITER in reply_to:
            # maxsplit=1 yields at most two parts, so the two-name unpacking
            # cannot raise ValueError; any further delimiters stay in the
            # exchange part.
            routing_key, exchange_name = reply_to.split(
                REPLY_TO_QUEUE_EXCHANGE_DELIMITER, 1
            )
            exchange = RabbitExchange.validate(exchange_name)
        else:
            routing_key = reply_to
            exchange = RabbitExchange()

        publisher = RabbitFakePublisher(
            self._outer_config.producer,
            app_id=self.app_id,
            routing_key=routing_key,
            exchange=exchange,
        )
        return (publisher,)

    @staticmethod
    def build_log_context(
        message: Optional["StreamMessage[Any]"],
        queue: "RabbitQueue",
        exchange: Optional["RabbitExchange"] = None,
    ) -> dict[str, str]:
        """Build the logging context dict for a queue/exchange/message triple.

        Missing ``exchange`` or ``message`` (``None``) yield empty strings for
        their fields.
        """
        return {
            "queue": queue.name,
            "exchange": getattr(exchange, "name", ""),
            "message_id": getattr(message, "message_id", ""),
        }

    def get_log_context(
        self,
        message: Optional["StreamMessage[Any]"],
    ) -> dict[str, str]:
        """Return the logging context for ``message`` using this subscriber's queue and exchange."""
        return self.build_log_context(
            message=message,
            queue=self.queue,
            exchange=self.exchange,
        )