Coverage for faststream / mqtt / broker / broker.py: 94%

50 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import logging 

2from collections.abc import Iterable, Sequence 

3from typing import ( 

4 TYPE_CHECKING, 

5 Any, 

6 Literal, 

7 Optional, 

8) 

9 

10import zmqtt 

11from fast_depends import Provider, dependency_provider 

12from typing_extensions import override 

13 

14from faststream._internal.broker import BrokerUsecase 

15from faststream._internal.constants import EMPTY 

16from faststream._internal.context.repository import ContextRepo 

17from faststream._internal.di import FastDependsConfig 

18from faststream.message import gen_cor_id 

19from faststream.middlewares import AckPolicy 

20from faststream.mqtt.broker.config import MQTTBrokerConfig 

21from faststream.mqtt.publisher.producer import ( 

22 ZmqttBaseProducer, 

23 ZmqttProducerV5, 

24 ZmqttProducerV311, 

25) 

26from faststream.mqtt.response import MQTTPublishCommand 

27from faststream.mqtt.security import parse_security 

28from faststream.mqtt.subscriber.usecase import MQTTBaseSubscriber 

29from faststream.response.publish_type import PublishType 

30from faststream.specification.schema import BrokerSpec 

31 

32from .logging import make_mqtt_logger_state 

33from .registrator import MQTTRegistrator 

34 

35if TYPE_CHECKING: 

36 from types import TracebackType 

37 

38 from fast_depends.dependencies import Dependant 

39 from fast_depends.library.serializer import SerializerProto 

40 

41 from faststream._internal.basic_types import LoggerProto, SendableMessage 

42 from faststream._internal.parser import CodecProto 

43 from faststream._internal.types import BrokerMiddleware, CustomCallable 

44 from faststream.mqtt.message import MQTTMessage 

45 from faststream.security import BaseSecurity 

46 from faststream.specification.schema.extra import Tag, TagDict 

47 

48 

class MQTTBroker(
    MQTTRegistrator,
    BrokerUsecase[zmqtt.Message, zmqtt.MQTTClient],
):
    """MQTT broker for FastStream using the zmqtt client library."""

    def __init__(
        self,
        host: str = "localhost:1883",
        port: int = EMPTY,
        *,
        client_id: str = "",
        keepalive: int = 60,
        clean_session: bool = True,
        version: Literal["3.1.1", "5.0"] = "5.0",
        reconnect: zmqtt.ReconnectConfig | None = None,
        session_expiry_interval: int = 0,
        graceful_timeout: float | None = 15.0,
        decoder: Optional["CustomCallable"] = None,
        parser: Optional["CustomCallable"] = None,
        codec: Optional["CodecProto"] = None,
        dependencies: Iterable["Dependant"] = (),
        middlewares: Sequence["BrokerMiddleware[Any, Any]"] = (),
        routers: Iterable[MQTTRegistrator] = (),
        ack_policy: AckPolicy = EMPTY,
        # AsyncAPI args
        specification_url: str | None = None,
        protocol_version: str | None = None,
        description: str | None = None,
        tags: Iterable["Tag | TagDict"] = (),
        security: Optional["BaseSecurity"] = None,
        # logging args
        logger: Optional["LoggerProto"] = EMPTY,
        log_level: int = logging.INFO,
        # FastDepends args
        apply_types: bool = True,
        serializer: Optional["SerializerProto"] = EMPTY,
        provider: Optional["Provider"] = None,
        context: Optional["ContextRepo"] = None,
    ) -> None:
        """Configure the MQTT broker.

        Args:
            host: Broker address. May embed a port as ``"host:port"``;
                bracketed IPv6 literals (``"[::1]"``, ``"[::1]:1883"``) are
                accepted too.
            port: Explicit port. When omitted, the port embedded in ``host``
                is used, falling back to ``1883``.
            version: MQTT protocol version. Selects the v5.0 or v3.1.1
                producer and is the default AsyncAPI ``protocol_version``.
            specification_url: AsyncAPI server URL; defaults to
                ``mqtt://<host>:<port>``.
            provider: FastDepends provider; defaults to the global
                ``dependency_provider``.
            context: Context repository; a new ``ContextRepo`` is created
                when omitted.

        The connection options (``client_id``, ``keepalive``,
        ``clean_session``, ``reconnect``, ``session_expiry_interval``) and
        the keyword arguments returned by ``parse_security`` are forwarded to
        ``BrokerUsecase.__init__``. They presumably become the
        ``_connection_kwargs`` that ``_connect`` passes to
        ``zmqtt.MQTTClient`` (TODO confirm in BrokerUsecase).
        """
        secure_kwargs = parse_security(security)

        producer: ZmqttBaseProducer
        if version == "5.0":
            producer = ZmqttProducerV5(parser=parser, decoder=decoder)
        else:
            producer = ZmqttProducerV311(parser=parser, decoder=decoder)

        host, embedded_port = self._split_host_port(host)
        if port is EMPTY:
            # An explicit ``port`` argument wins over one embedded in ``host``.
            port = int(embedded_port) if embedded_port else 1883

        if specification_url is None:
            # IPv6 literals must be bracketed inside a URL authority.
            url_host = f"[{host}]" if ":" in host else host
            specification_url = f"mqtt://{url_host}:{port}"

        super().__init__(
            host=host,
            port=port,
            client_id=client_id,
            keepalive=keepalive,
            clean_session=clean_session,
            version=version,
            reconnect=reconnect,
            session_expiry_interval=session_expiry_interval,
            **secure_kwargs,
            # broker config
            routers=routers,
            config=MQTTBrokerConfig(
                version=version,
                producer=producer,
                broker_middlewares=middlewares,
                broker_parser=parser,
                broker_decoder=decoder,
                broker_codec=codec,
                logger=make_mqtt_logger_state(
                    logger=logger,
                    log_level=log_level,
                ),
                fd_config=FastDependsConfig(
                    use_fastdepends=apply_types,
                    serializer=serializer,
                    provider=provider or dependency_provider,
                    context=context or ContextRepo(),
                ),
                broker_dependencies=dependencies,
                graceful_timeout=graceful_timeout,
                ack_policy=ack_policy,
                extra_context={
                    "broker": self,
                },
            ),
            specification=BrokerSpec(
                description=description,
                url=[specification_url],
                protocol="mqtt",
                protocol_version=protocol_version or version,
                tags=tags,
                security=security,
            ),
        )

    @staticmethod
    def _split_host_port(address: str) -> tuple[str, str | None]:
        """Split ``address`` into a host and an optional port string.

        Accepted forms: ``"host"``, ``"host:port"``, ``"[ipv6]"`` and
        ``"[ipv6]:port"``. A bare IPv6 literal such as ``"::1"`` (more than
        one colon, no brackets) is returned unchanged with no port, because
        its colons cannot be told apart from a port separator.

        Returns:
            ``(host, port)``, where ``port`` is ``None`` when no port was
            given. An empty port, as in ``"host:"``, is also ``None``.
        """
        if address.startswith("[") and "]" in address:
            inner, _, rest = address[1:].partition("]")
            port = rest[1:] if rest.startswith(":") else ""
            return inner, port or None
        if address.count(":") == 1:
            name, _, port = address.partition(":")
            return name, port or None
        return address, None

    @override
    async def _connect(self) -> zmqtt.MQTTClient:
        """Create and connect a zmqtt client, then attach it to the config."""
        client = zmqtt.MQTTClient(**self._connection_kwargs)
        await client.connect()
        self.config.connect(client)
        return client

    @override
    async def start(self) -> None:
        """Connect, log the established connection, then run the base start-up."""
        await self.connect()
        log_context = MQTTBaseSubscriber.build_log_context(None, "")
        self.config.logger.log("Connection established", logging.INFO, log_context)
        await super().start()

    @override
    async def stop(
        self,
        exc_type: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: Optional["TracebackType"] = None,
    ) -> None:
        """Run the base shutdown, then disconnect the client and detach it.

        The connection reference and the config are reset even if
        ``disconnect()`` raises. Without that, a failed disconnect would leave
        the broker holding a stale client. The exception still propagates.
        """
        await super().stop(exc_type, exc_val, exc_tb)

        try:
            if self._connection is not None:
                await self._connection.disconnect()
        finally:
            self._connection = None
            # Counterpart of ``self.config.connect(client)`` in ``_connect``.
            self.config.disconnect()

    @override
    async def ping(self, timeout: float | None = None) -> bool:
        """Return ``True`` if the broker answers a ping, ``False`` otherwise.

        Returns ``False`` when not connected. Any exception raised by the
        ping is deliberately reported as ``False`` rather than propagated. A
        falsy ``timeout`` (``None`` or ``0``) falls back to 5 seconds.
        """
        if self._connection is None:
            return False
        try:
            await self._connection.ping(timeout=timeout or 5.0)
        except Exception:
            return False
        else:
            return True

    @override
    async def publish(
        self,
        message: "SendableMessage" = None,
        topic: str = "",
        *,
        qos: zmqtt.QoS = zmqtt.QoS.AT_MOST_ONCE,
        retain: bool = False,
        headers: dict[str, str] | None = None,
        correlation_id: str | None = None,
        reply_to: str = "",
    ) -> None:
        """Publish a message to an MQTT topic.

        Args:
            message: Message body to send.
            topic: MQTT topic to publish to.
            qos: QoS level (0, 1, or 2).
            retain: Whether the broker should retain the message.
            headers: Message headers (MQTT 5.0 user properties).
            correlation_id: Correlation ID for message tracing; generated
                when omitted.
            reply_to: Response topic (MQTT 5.0 response_topic property).
        """
        cmd = MQTTPublishCommand(
            message,
            topic=topic,
            qos=qos,
            retain=retain,
            headers=headers,
            correlation_id=correlation_id or gen_cor_id(),
            reply_to=reply_to,
            _publish_type=PublishType.PUBLISH,
        )

        await self._basic_publish(cmd, producer=self.config.producer)

    @override
    async def request(
        self,
        message: "SendableMessage" = None,
        topic: str = "",
        /,
        timeout: float = 0.5,
        correlation_id: str | None = None,
        headers: dict[str, str] | None = None,
        qos: zmqtt.QoS = zmqtt.QoS.AT_MOST_ONCE,
        reply_to: str = "",
    ) -> "MQTTMessage":
        """Publish a message as a request and return the reply message.

        Args:
            message: Message body to send (positional-only).
            topic: MQTT topic to publish to (positional-only).
            timeout: Value passed to the command as its ``timeout``;
                presumably seconds to wait for a reply (TODO confirm).
            correlation_id: Correlation ID; generated when omitted.
            headers: Message headers (MQTT 5.0 user properties).
            qos: QoS level (0, 1, or 2).
            reply_to: Response topic (MQTT 5.0 response_topic property).
        """
        cmd = MQTTPublishCommand(
            message,
            topic=topic,
            correlation_id=correlation_id or gen_cor_id(),
            headers=headers,
            qos=qos,
            reply_to=reply_to,
            timeout=timeout,
            _publish_type=PublishType.REQUEST,
        )
        msg: MQTTMessage = await self._basic_request(cmd, producer=self.config.producer)
        return msg