Coverage for faststream / mqtt / subscriber / usecase.py: 89%
86 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import warnings
2from abc import abstractmethod
3from collections.abc import AsyncIterator, Sequence
4from contextlib import suppress
5from typing import TYPE_CHECKING, Any
7import anyio
8import zmqtt
9from typing_extensions import override
11from faststream._internal.endpoint.subscriber import SubscriberUsecase
12from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin, TasksMixin
13from faststream._internal.endpoint.utils import process_msg
14from faststream.middlewares import AckPolicy
15from faststream.mqtt.parser import MQTTBaseParser, MQTTParserV5, MQTTParserV311
16from faststream.mqtt.publisher.fake import MQTTFakePublisher
18if TYPE_CHECKING:
19 from faststream._internal.endpoint.publisher import PublisherProto
20 from faststream._internal.endpoint.subscriber import SubscriberSpecification
21 from faststream._internal.endpoint.subscriber.call_item import CallsCollection
22 from faststream.message import StreamMessage
23 from faststream.mqtt.broker.config import MQTTBrokerConfig
24 from faststream.mqtt.message import MQTTMessage
25 from faststream.mqtt.subscriber.config import MQTTSubscriberConfig
class MQTTBaseSubscriber(TasksMixin, SubscriberUsecase[zmqtt.Message]):
    """Base class for all MQTT subscribers.

    Holds at most one ``zmqtt`` subscription (created on ``start`` when
    handlers are registered, or lazily by ``get_one`` / ``__aiter__``) and
    installs the protocol-version-specific parser/decoder. Subclasses only
    decide how messages read from the subscription are dispatched, via
    ``_consume_loop``.
    """

    _outer_config: "MQTTBrokerConfig"

    def __init__(
        self,
        config: "MQTTSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[zmqtt.Message]",
    ) -> None:
        """Create the subscriber from its config.

        Args:
            config: subscriber config; its ``parser``/``decoder`` fields are
                overwritten with the MQTT parser chosen for the protocol version.
            specification: specification object forwarded to the base class.
            calls: registered handler calls forwarded to the base class.
        """
        # version may not be available yet when subscriber is created on a router
        # before include_router is called; default to V5 and re-resolve in start().
        parser = self._select_parser(config._outer_config)
        config.parser = parser.parse_message
        config.decoder = parser.decode_message
        super().__init__(config, specification, calls)
        self._topic = config.topic
        self._shared = config.shared
        self._qos = config.qos
        self._subscription: zmqtt.Subscription | None = None

        if config.ack_policy is AckPolicy.NACK_ON_ERROR:
            warnings.warn(
                "MQTT has no nack primitive; with NACK_ON_ERROR, "
                "on error QoS 1/2 messages will not be acknowledged "
                "and the broker will redeliver them.",
                RuntimeWarning,
                stacklevel=3,
            )

    @staticmethod
    def _select_parser(outer_config: Any) -> MQTTBaseParser:
        """Return the parser matching the configured MQTT protocol version.

        ``outer_config`` may not expose ``version`` yet (for example a
        subscriber declared on a router before ``include_router``). In that
        case the value is treated as ``"5.0"`` and the MQTT 5 parser is used.
        """
        if getattr(outer_config, "version", "5.0") == "3.1.1":
            return MQTTParserV311()
        return MQTTParserV5()

    @property
    def topic(self) -> str:
        """Subscription filter: broker prefix + topic, wrapped as ``$share/<group>/...`` when shared."""
        full = f"{self._outer_config.prefix}{self._topic}"
        return f"$share/{self._shared}/{full}" if self._shared else full

    def _make_response_publisher(
        self,
        message: "StreamMessage[Any]",
    ) -> Sequence["PublisherProto"]:
        """Build the publisher used to send a handler's reply to ``message.reply_to``."""
        return (
            MQTTFakePublisher(
                producer=self._outer_config.producer,
                topic=message.reply_to,
            ),
        )

    @staticmethod
    def build_log_context(
        message: "StreamMessage[zmqtt.Message] | None",
        topic: str = "",
    ) -> dict[str, str]:
        """Return the logging context for ``message`` on ``topic``.

        ``message_id`` falls back to ``""`` when ``message`` is ``None`` or
        lacks the attribute.
        """
        return {
            "topic": topic,
            "message_id": getattr(message, "message_id", ""),
        }

    def get_log_context(
        self,
        message: "StreamMessage[zmqtt.Message] | None",
    ) -> dict[str, str]:
        """Return the logging context for ``message`` using this subscriber's topic."""
        return self.build_log_context(message=message, topic=self.topic)

    @override
    async def start(self) -> None:
        """Start the subscriber and, if handlers exist, begin consuming."""
        # Re-resolve the parser now that _outer_config is fully composed
        # (i.e. include_router has been called and the broker's MQTTBrokerConfig
        # is reachable through the config chain).
        parser = self._select_parser(self._outer_config)
        self._parser = parser.parse_message
        self._decoder = parser.decode_message

        await super().start()

        # Without handlers the subscriber is only used via get_one/__aiter__,
        # which create the subscription on demand.
        if self.calls:
            await self._create_subscription()
            self.add_task(self._consume_loop)

        self._post_start()

    @override
    async def stop(self) -> None:
        """Stop the subscriber and release its zmqtt subscription.

        The subscription is torn down in ``finally`` so it is released even
        when the base-class shutdown raises. Errors from the subscription's
        own ``stop()`` are suppressed, as a best-effort cleanup.
        """
        try:
            await super().stop()
        finally:
            # Detach first, so a cancellation while awaiting stop() cannot
            # leave a stale reference behind.
            subscription, self._subscription = self._subscription, None
            if subscription is not None:
                with suppress(Exception):
                    await subscription.stop()

    async def _create_subscription(self) -> None:
        """Subscribe to ``self.topic`` with the configured QoS and start the subscription.

        ``AckPolicy.ACK_FIRST`` maps to zmqtt's ``auto_ack=True``. Every other
        policy leaves acknowledgement to the subscriber.
        """
        auto_ack = self.ack_policy is AckPolicy.ACK_FIRST
        self._subscription = self._outer_config.client.subscribe(
            self.topic,
            qos=zmqtt.QoS(self._qos),
            auto_ack=auto_ack,
        )
        await self._subscription.start()

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
    ) -> "StreamMessage[zmqtt.Message] | None":
        """Fetch a single message, waiting at most ``timeout`` seconds.

        Only allowed on subscribers without registered handlers. The
        subscription is created on first use and reused by later calls.

        Returns:
            The parsed message. Returns ``None`` when nothing arrived in time:
            ``raw_msg`` stays ``None`` and is handed to ``process_msg``, which
            is assumed to short-circuit on ``None``.
        """
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        if self._subscription is None:
            await self._create_subscription()
        assert self._subscription is not None

        async_parser, async_decoder = self._get_parser_and_decoder()

        raw_msg: zmqtt.Message | None = None
        # move_on_after cancels the wait silently on timeout; raw_msg stays None.
        with anyio.move_on_after(timeout):
            raw_msg = await self._subscription.get_message()

        context = self._outer_config.fd_config.context
        return await process_msg(
            msg=raw_msg,
            middlewares=(m(raw_msg, context=context) for m in self._broker_middlewares),
            parser=async_parser,
            decoder=async_decoder,
        )

    @override
    async def __aiter__(self) -> AsyncIterator["StreamMessage[zmqtt.Message]"]:  # type: ignore[override]
        """Yield parsed messages from the subscription, creating it if needed."""
        if self._subscription is None:
            await self._create_subscription()

        assert self._subscription is not None
        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()
        async for raw_msg in self._subscription:
            msg: MQTTMessage = await process_msg(  # type: ignore[assignment]
                msg=raw_msg,
                middlewares=(
                    m(raw_msg, context=context) for m in self._broker_middlewares
                ),
                parser=async_parser,
                decoder=async_decoder,
            )
            yield msg

    @abstractmethod
    async def _consume_loop(self) -> None:
        """Read messages from ``self._subscription`` and dispatch them (subclass-specific)."""
        raise NotImplementedError
class MQTTDefaultSubscriber(MQTTBaseSubscriber):
    """MQTT subscriber that handles messages strictly one after another."""

    async def _consume_loop(self) -> None:
        """Pass each incoming message to ``consume``, awaiting it before reading the next."""
        subscription = self._subscription
        assert subscription is not None
        async for raw_message in subscription:
            await self.consume(raw_message)
class MQTTConcurrentSubscriber(ConcurrentMixin[zmqtt.Message], MQTTBaseSubscriber):
    """MQTT subscriber that processes up to ``max_workers`` messages in parallel.

    Worker scheduling is presumably provided by ``ConcurrentMixin``
    (``start_consume_task`` / ``_put_msg``).
    """

    @override
    async def start(self) -> None:
        """Run the base start-up sequence, then launch the concurrent consume task."""
        await super().start()
        self.start_consume_task()

    async def _consume_loop(self) -> None:
        """Hand every message read from the subscription to ``_put_msg``."""
        subscription = self._subscription
        assert subscription is not None
        async for incoming in subscription:
            await self._put_msg(incoming)