Coverage for faststream/mqtt/broker/broker.py: 94%
50 statements
coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import logging
2from collections.abc import Iterable, Sequence
3from typing import (
4 TYPE_CHECKING,
5 Any,
6 Literal,
7 Optional,
8)
10import zmqtt
11from fast_depends import Provider, dependency_provider
12from typing_extensions import override
14from faststream._internal.broker import BrokerUsecase
15from faststream._internal.constants import EMPTY
16from faststream._internal.context.repository import ContextRepo
17from faststream._internal.di import FastDependsConfig
18from faststream.message import gen_cor_id
19from faststream.middlewares import AckPolicy
20from faststream.mqtt.broker.config import MQTTBrokerConfig
21from faststream.mqtt.publisher.producer import (
22 ZmqttBaseProducer,
23 ZmqttProducerV5,
24 ZmqttProducerV311,
25)
26from faststream.mqtt.response import MQTTPublishCommand
27from faststream.mqtt.security import parse_security
28from faststream.mqtt.subscriber.usecase import MQTTBaseSubscriber
29from faststream.response.publish_type import PublishType
30from faststream.specification.schema import BrokerSpec
32from .logging import make_mqtt_logger_state
33from .registrator import MQTTRegistrator
35if TYPE_CHECKING:
36 from types import TracebackType
38 from fast_depends.dependencies import Dependant
39 from fast_depends.library.serializer import SerializerProto
41 from faststream._internal.basic_types import LoggerProto, SendableMessage
42 from faststream._internal.parser import CodecProto
43 from faststream._internal.types import BrokerMiddleware, CustomCallable
44 from faststream.mqtt.message import MQTTMessage
45 from faststream.security import BaseSecurity
46 from faststream.specification.schema.extra import Tag, TagDict
def _split_host_port(host: str, default_port: str = "1883") -> tuple[str, str]:
    """Split a ``host[:port]`` string into its host and port parts.

    Accepted forms:

    * ``"name"`` -> ``("name", default_port)``
    * ``"name:port"`` -> ``("name", "port")``
    * ``"[::1]"`` -> ``("::1", default_port)`` (bracketed IPv6 literal)
    * ``"[::1]:1884"`` -> ``("::1", "1884")``
    * ``"::1"`` -> ``("::1", default_port)`` (bare IPv6 literal; its colons
      belong to the address, so it cannot carry a port)

    The port is returned as a string and is not validated here.

    Raises:
        ValueError: If ``host`` starts with ``[`` but is not a well-formed
            bracketed literal optionally followed by ``:port``.
    """
    if host.startswith("["):
        address, closed, tail = host[1:].partition("]")
        if closed and not tail:
            return address, default_port
        if closed and tail.startswith(":"):
            return address, tail[1:]
        msg = f"Invalid MQTT host: {host!r}"
        raise ValueError(msg)

    if host.count(":") == 1:
        name, _, port = host.partition(":")
        return name, port

    # Either no port at all, or a bare IPv6 literal whose colons are part of
    # the address rather than a host/port separator.
    return host, default_port


class MQTTBroker(
    MQTTRegistrator,
    BrokerUsecase[zmqtt.Message, zmqtt.MQTTClient],
):
    """MQTT broker for FastStream using the zmqtt client library.

    The MQTT protocol ``version`` selects the producer implementation
    (``ZmqttProducerV5`` for ``"5.0"``, ``ZmqttProducerV311`` otherwise), and
    connections are made with :class:`zmqtt.MQTTClient`.
    """

    def __init__(
        self,
        host: str = "localhost:1883",
        port: int = EMPTY,
        *,
        client_id: str = "",
        keepalive: int = 60,
        clean_session: bool = True,
        version: Literal["3.1.1", "5.0"] = "5.0",
        reconnect: zmqtt.ReconnectConfig | None = None,
        session_expiry_interval: int = 0,
        graceful_timeout: float | None = 15.0,
        decoder: Optional["CustomCallable"] = None,
        parser: Optional["CustomCallable"] = None,
        codec: Optional["CodecProto"] = None,
        dependencies: Iterable["Dependant"] = (),
        middlewares: Sequence["BrokerMiddleware[Any, Any]"] = (),
        routers: Iterable[MQTTRegistrator] = (),
        ack_policy: AckPolicy = EMPTY,
        # AsyncAPI args
        specification_url: str | None = None,
        protocol_version: str | None = None,
        description: str | None = None,
        tags: Iterable["Tag | TagDict"] = (),
        security: Optional["BaseSecurity"] = None,
        # logging args
        logger: Optional["LoggerProto"] = EMPTY,
        log_level: int = logging.INFO,
        # FastDepends args
        apply_types: bool = True,
        serializer: Optional["SerializerProto"] = EMPTY,
        provider: Optional["Provider"] = None,
        context: Optional["ContextRepo"] = None,
    ) -> None:
        """Build the broker configuration.

        Args:
            host: Broker address, optionally suffixed with ``:port``.
                Bracketed (``"[::1]:1883"``) and bare (``"::1"``) IPv6
                literals are accepted.
            port: Broker port. An explicit value takes precedence over a
                port embedded in ``host``; when neither is given, 1883 is
                used.
            client_id: Connection option forwarded to the base class.
            keepalive: Connection option forwarded to the base class.
            clean_session: Connection option forwarded to the base class.
            version: MQTT protocol version. It selects the v5 or v3.1.1
                producer and is the AsyncAPI protocol version unless
                ``protocol_version`` is given.
            reconnect: Connection option forwarded to the base class.
            session_expiry_interval: Connection option forwarded to the base
                class.
            graceful_timeout: Stored on ``MQTTBrokerConfig``.
            decoder: Broker-level decoder, also given to the producer.
            parser: Broker-level parser, also given to the producer.
            codec: Broker-level codec.
            dependencies: Broker-level dependencies.
            middlewares: Broker-level middlewares.
            routers: Routers passed to the base class.
            ack_policy: Stored on ``MQTTBrokerConfig``.
            specification_url: AsyncAPI server URL. Defaults to
                ``mqtt://{host}:{port}``, with IPv6 hosts bracketed.
            protocol_version: AsyncAPI protocol version; defaults to
                ``version``.
            description: AsyncAPI metadata.
            tags: AsyncAPI metadata.
            security: AsyncAPI security, also parsed by ``parse_security``
                into extra connection options.
            logger: Logger setup passed to ``make_mqtt_logger_state``.
            log_level: Log level passed to ``make_mqtt_logger_state``.
            apply_types: FastDepends option.
            serializer: FastDepends serializer.
            provider: FastDepends provider; defaults to the global
                ``dependency_provider``.
            context: FastDepends context; defaults to a fresh
                ``ContextRepo``.

        Raises:
            ValueError: If ``host`` is a malformed bracketed IPv6 literal, or
                if ``port`` is not given and the port embedded in ``host`` is
                not an integer.
        """
        secure_kwargs = parse_security(security)

        producer: ZmqttBaseProducer
        if version == "5.0":
            producer = ZmqttProducerV5(parser=parser, decoder=decoder)
        else:
            producer = ZmqttProducerV311(parser=parser, decoder=decoder)

        # Previously ``host.split(":", 2)`` was unpacked into two names and
        # crashed on any host containing more than one colon (e.g. IPv6).
        # NOTE(review): IPv6 hosts are handed to the client without brackets;
        # assumes zmqtt expects that form — confirm.
        host, embedded_port = _split_host_port(host)
        if port is EMPTY:
            port = int(embedded_port)

        if specification_url is None:
            # IPv6 literals must be bracketed inside a URL authority.
            url_host = f"[{host}]" if ":" in host else host
            specification_url = f"mqtt://{url_host}:{port}"

        super().__init__(
            host=host,
            port=port,
            client_id=client_id,
            keepalive=keepalive,
            clean_session=clean_session,
            version=version,
            reconnect=reconnect,
            session_expiry_interval=session_expiry_interval,
            **secure_kwargs,
            # broker config
            routers=routers,
            config=MQTTBrokerConfig(
                version=version,
                producer=producer,
                broker_middlewares=middlewares,
                broker_parser=parser,
                broker_decoder=decoder,
                broker_codec=codec,
                logger=make_mqtt_logger_state(
                    logger=logger,
                    log_level=log_level,
                ),
                fd_config=FastDependsConfig(
                    use_fastdepends=apply_types,
                    serializer=serializer,
                    provider=provider or dependency_provider,
                    context=context or ContextRepo(),
                ),
                broker_dependencies=dependencies,
                graceful_timeout=graceful_timeout,
                ack_policy=ack_policy,
                extra_context={
                    "broker": self,
                },
            ),
            specification=BrokerSpec(
                description=description,
                url=[specification_url],
                protocol="mqtt",
                protocol_version=protocol_version or version,
                tags=tags,
                security=security,
            ),
        )

    @override
    async def _connect(self) -> zmqtt.MQTTClient:
        """Create a zmqtt client, connect it, and register it with the config.

        The client is built from ``self._connection_kwargs`` (presumably the
        connection options stored by the base class from ``__init__``).
        """
        client = zmqtt.MQTTClient(**self._connection_kwargs)
        await client.connect()
        self.config.connect(client)
        return client

    @override
    async def start(self) -> None:
        """Connect, log "Connection established", then run the base start."""
        await self.connect()
        log_context = MQTTBaseSubscriber.build_log_context(None, "")
        self.config.logger.log("Connection established", logging.INFO, log_context)
        await super().start()

    @override
    async def stop(
        self,
        exc_type: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: Optional["TracebackType"] = None,
    ) -> None:
        """Stop the broker, then disconnect the client and detach the config.

        The client disconnect and ``config.disconnect()`` run in a ``finally``
        clause. An error raised during the base-class shutdown therefore no
        longer leaks the open MQTT connection; that error still propagates.
        """
        try:
            await super().stop(exc_type, exc_val, exc_tb)
        finally:
            if self._connection is not None:
                await self._connection.disconnect()
                self._connection = None

            self.config.disconnect()

    @override
    async def ping(self, timeout: float | None = None) -> bool:
        """Return whether the broker answers a ping.

        Returns False when there is no connection or when the ping raises any
        exception. A falsy ``timeout`` falls back to 5 seconds (the value
        passed to ``MQTTClient.ping``; units presumably seconds).
        """
        if self._connection is None:
            return False
        try:
            await self._connection.ping(timeout=timeout or 5.0)
        except Exception:
            # Best-effort health check: any failure means "not alive".
            return False
        else:
            return True

    @override
    async def publish(
        self,
        message: "SendableMessage" = None,
        topic: str = "",
        *,
        qos: zmqtt.QoS = zmqtt.QoS.AT_MOST_ONCE,
        retain: bool = False,
        headers: dict[str, str] | None = None,
        correlation_id: str | None = None,
        reply_to: str = "",
    ) -> None:
        """Publish a message to an MQTT topic.

        Args:
            message: Message body to send.
            topic: MQTT topic to publish to.
            qos: QoS level (0, 1, or 2).
            retain: Whether the broker should retain the message.
            headers: Message headers (MQTT 5.0 user properties).
            correlation_id: Correlation ID for message tracing; a new one is
                generated when empty or None.
            reply_to: Response topic (MQTT 5.0 response_topic property).
        """
        cmd = MQTTPublishCommand(
            message,
            topic=topic,
            qos=qos,
            retain=retain,
            headers=headers,
            correlation_id=correlation_id or gen_cor_id(),
            reply_to=reply_to,
            _publish_type=PublishType.PUBLISH,
        )

        await self._basic_publish(cmd, producer=self.config.producer)

    @override
    async def request(
        self,
        message: "SendableMessage" = None,
        topic: str = "",
        /,
        timeout: float = 0.5,
        correlation_id: str | None = None,
        headers: dict[str, str] | None = None,
        qos: zmqtt.QoS = zmqtt.QoS.AT_MOST_ONCE,
        reply_to: str = "",
    ) -> "MQTTMessage":
        """Publish a request message and return the reply.

        Args:
            message: Message body to send.
            topic: MQTT topic to publish to.
            timeout: How long to wait for the reply. It is carried on the
                command as ``timeout``; units are presumably seconds.
            correlation_id: Correlation ID; a new one is generated when empty
                or None.
            headers: Message headers (MQTT 5.0 user properties).
            qos: QoS level (0, 1, or 2).
            reply_to: Response topic carried on the command.

        Returns:
            The reply message produced by ``_basic_request``.
        """
        cmd = MQTTPublishCommand(
            message,
            topic=topic,
            correlation_id=correlation_id or gen_cor_id(),
            headers=headers,
            qos=qos,
            reply_to=reply_to,
            timeout=timeout,
            _publish_type=PublishType.REQUEST,
        )
        msg: MQTTMessage = await self._basic_request(cmd, producer=self.config.producer)
        return msg