Coverage for faststream / mqtt / testing.py: 95%

89 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import asyncio 

2from collections.abc import Iterator 

3from contextlib import contextmanager 

4from typing import TYPE_CHECKING, Any, Literal, Optional, cast 

5from unittest.mock import MagicMock 

6 

7import anyio 

8import zmqtt 

9from typing_extensions import override 

10from zmqtt._internal.protocol import _shared_filter_to_actual, _topic_matches 

11 

12from faststream._internal.endpoint.utils import ParserComposition 

13from faststream._internal.testing.broker import TestBroker, change_producer 

14from faststream.exceptions import SubscriberNotFound 

15from faststream.message import encode_message 

16from faststream.mqtt.broker.broker import MQTTBroker 

17from faststream.mqtt.parser import MQTTParserV5, MQTTParserV311 

18from faststream.mqtt.publisher.producer import ZmqttBaseProducer 

19from faststream.mqtt.response import MQTTPublishCommand 

20 

21if TYPE_CHECKING: 

22 from fast_depends.library.serializer import SerializerProto 

23 

24 from faststream._internal.basic_types import SendableMessage 

25 from faststream.mqtt.publisher.usecase import MQTTPublisher 

26 from faststream.mqtt.subscriber.usecase import MQTTBaseSubscriber 

27 

# Only the test broker is exported via ``from ... import *``; the helpers
# defined below (FakeProducer, build_message, ...) remain importable by name.
__all__ = ("TestMQTTBroker",)

29 

30 

class _BlockingSubscription:
    """Fake ``zmqtt.Subscription`` whose iterator never yields a message.

    ``TestMQTTBroker._fake_connect`` installs an instance as the return value
    of the fake client's ``subscribe``. This lets subscribers (including ones
    added dynamically) call ``start()`` without a real MQTT connection.
    Messages in tests are delivered by ``FakeProducer`` calling
    ``process_message`` directly, never through this iterator.
    """

    async def start(self) -> None:
        """No-op: there is no server-side subscription to establish."""

    async def stop(self) -> None:
        """No-op: there is no server-side subscription to tear down."""

    def __aiter__(self) -> "_BlockingSubscription":
        return self

    async def __anext__(self) -> zmqtt.Message:
        # Block until the awaiting task is cancelled (i.e. subscriber.stop()).
        # anyio.sleep_forever() works under any anyio backend and avoids the
        # magic "very large timeout" of asyncio.sleep(1e9).
        await anyio.sleep_forever()
        raise StopAsyncIteration  # pragma: no cover

52 

53 

def mqtt_topic_matches(pattern: str, topic: str) -> bool:
    """Return whether the subscription filter *pattern* matches *topic*.

    *pattern* is first normalised with zmqtt's ``_shared_filter_to_actual``
    (presumably stripping a ``$share/<group>/`` prefix), then matched with
    zmqtt's ``_topic_matches``.
    """
    effective_filter = _shared_filter_to_actual(pattern)
    return _topic_matches(effective_filter, topic)

56 

57 

def _broker_version(broker: MQTTBroker) -> Literal["3.1.1", "5.0"]:
    """Return the MQTT protocol version configured on *broker*.

    Falls back to ``"5.0"`` when the broker config has no ``version`` attribute.
    """
    broker_config = broker.config.broker_config
    return getattr(broker_config, "version", "5.0")

60 

61 

def _parser_for_version(
    version: Literal["3.1.1", "5.0"],
) -> MQTTParserV311 | MQTTParserV5:
    """Create a new default parser for the given MQTT protocol *version*.

    ``"3.1.1"`` yields an ``MQTTParserV311``; any other value an ``MQTTParserV5``.
    """
    if version == "3.1.1":
        return MQTTParserV311()
    return MQTTParserV5()

66 

67 

class TestMQTTBroker(TestBroker[MQTTBroker]):
    """In-memory test double for ``MQTTBroker``.

    Published messages are handed straight to the subscribers whose topic
    filters match (MQTT wildcard rules via ``mqtt_topic_matches``), so no
    MQTT server or connection is needed. Fake messages are built for the
    protocol version the wrapped broker is configured with (``"3.1.1"`` or
    ``"5.0"``), and subscribers get the matching default parser.

    Usage::

        async with TestMQTTBroker(broker) as br:
            await br.publish("hello", "sensors/temp")
            handler.mock.assert_called_once_with("hello")
    """

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: MQTTBroker,
        publisher: "MQTTPublisher",
    ) -> tuple["MQTTBaseSubscriber", bool]:
        """Return ``(subscriber, is_real)`` for *publisher*.

        If a registered subscriber's filter matches the publisher topic, the
        first such subscriber is returned with ``True``. Otherwise a new
        non-persistent subscriber is registered on the publisher topic, given
        the version-correct default parser/decoder (so it understands what
        ``FakeProducer`` emits), and returned with ``False``.
        """
        existing = next(
            (
                candidate
                for candidate in cast("list[MQTTBaseSubscriber]", broker.subscribers)
                if mqtt_topic_matches(candidate.topic, publisher.topic)
            ),
            None,
        )
        if existing is not None:
            return existing, True

        fake_sub = broker.subscriber(publisher.topic, persistent=False)
        version_parser = _parser_for_version(_broker_version(broker))
        fake_sub._parser = version_parser.parse_message
        fake_sub._decoder = version_parser.decode_message
        return fake_sub, False

    def _fake_start(self, broker: MQTTBroker, *args: Any, **kwargs: Any) -> None:
        """Give every registered subscriber the version-correct parser, then start.

        This must happen before the base ``_fake_start`` runs (which, per the
        original note, builds the fast-depends model in patch_broker_calls).
        NOTE(review): this replaces whatever ``_parser``/``_decoder`` each
        subscriber already had. Confirm no user-supplied parser is lost.
        """
        version_parser = _parser_for_version(_broker_version(broker))
        for subscriber in cast("list[MQTTBaseSubscriber]", broker.subscribers):
            subscriber._parser = version_parser.parse_message
            subscriber._decoder = version_parser.decode_message
        super()._fake_start(broker, *args, **kwargs)

    @contextmanager
    def _patch_producer(self, broker: MQTTBroker) -> Iterator[None]:
        """Swap the broker's producer for an in-memory ``FakeProducer``."""
        with change_producer(broker.config.broker_config, FakeProducer(broker)):
            yield

    async def _fake_connect(  # type: ignore[override]
        self,
        broker: MQTTBroker,
        *args: Any,
        **kwargs: Any,
    ) -> MagicMock:
        """Install and return a ``MagicMock`` client instead of connecting.

        The mock's ``subscribe`` returns a ``_BlockingSubscription``. The mock
        is stored as ``broker.config.broker_config._client``, so subscribers
        started later can run without a real MQTT connection.
        """
        client_stub = MagicMock()
        client_stub.subscribe.return_value = _BlockingSubscription()
        broker.config.broker_config._client = client_stub
        return client_stub

134 

135 

class FakeProducer(ZmqttBaseProducer):
    """In-memory producer that routes messages directly to matching subscribers.

    Outgoing messages are built with ``build_message`` for the broker's
    configured MQTT version:

    * ``"5.0"``: metadata is carried in ``PublishProperties``.
    * ``"3.1.1"``: only topic, payload, QoS and retain are set.
    """

    def __init__(self, broker: MQTTBroker) -> None:
        self.broker = broker
        self.serializer: SerializerProto | None = None

        # Combine the broker-level parser/decoder hooks with the default
        # parser for the broker's MQTT version.
        version = _broker_version(broker)
        default = _parser_for_version(version)
        self._parser = ParserComposition(broker._parser, default.parse_message)
        self._decoder = ParserComposition(broker._decoder, default.decode_message)

    @property
    def _version(self) -> Literal["3.1.1", "5.0"]:
        """MQTT version of the wrapped broker, re-read on every access."""
        return _broker_version(self.broker)

    @override
    async def publish(self, cmd: MQTTPublishCommand) -> None:
        """Deliver *cmd* to every subscriber whose filter matches its topic.

        A shared subscription is identified by its share name together with
        its topic filter (MQTT 5.0, section 4.8.2). The full
        ``$share/<group>/<filter>`` string is therefore the de-duplication
        key, so each shared subscription receives at most one copy.
        Different filters in the same group are separate shared subscriptions
        and each receive the message.
        """
        msg = build_message(
            message=cmd.body,
            topic=cmd.destination,
            version=self._version,
            qos=cmd.qos,
            retain=cmd.retain,
            reply_to=cmd.reply_to,
            correlation_id=cmd.correlation_id,
            headers=cmd.headers,
            serializer=self.broker.config.fd_config._serializer,
        )

        delivered_shared: set[str] = set()

        for handler in cast("list[MQTTBaseSubscriber]", self.broker.subscribers):
            handler_topic = handler.topic
            if not mqtt_topic_matches(handler_topic, cmd.destination):
                continue

            if handler_topic.startswith("$share/"):
                # Key on the whole filter rather than the group name alone.
                # The old split("/", 2) unpacking also raised ValueError on a
                # "$share/<group>" filter with no topic part.
                if handler_topic in delivered_shared:
                    continue
                delivered_shared.add(handler_topic)

            await handler.process_message(msg)

    @override
    async def request(self, cmd: MQTTPublishCommand) -> "zmqtt.Message":
        """Run the first matching subscriber and return its reply as a message.

        The handler call is bounded by ``cmd.timeout`` (30 seconds when the
        timeout is falsy); exceeding it raises ``TimeoutError`` from
        ``anyio.fail_after``.

        Raises:
            SubscriberNotFound: no registered subscriber matches
                ``cmd.destination``.
        """
        msg = build_message(
            message=cmd.body,
            topic=cmd.destination,
            version=self._version,
            qos=cmd.qos,
            retain=cmd.retain,
            correlation_id=cmd.correlation_id,
            headers=cmd.headers,
            serializer=self.broker.config.fd_config._serializer,
        )

        for handler in cast("list[MQTTBaseSubscriber]", self.broker.subscribers):
            if not mqtt_topic_matches(handler.topic, cmd.destination):
                continue

            with anyio.fail_after(cmd.timeout or 30.0):
                result = await handler.process_message(msg)

            return build_message(
                message=result.body,
                topic=cmd.destination,
                version=self._version,
                correlation_id=result.correlation_id,
                headers=result.headers,
                serializer=self.broker.config.fd_config._serializer,
            )

        raise SubscriberNotFound

216 

217 

def build_message(
    message: "SendableMessage",
    topic: str,
    *,
    version: Literal["3.1.1", "5.0"] = "5.0",
    qos: int = 0,
    retain: bool = False,
    reply_to: str = "",
    correlation_id: str | None = None,
    headers: dict[str, str] | None = None,
    serializer: Optional["SerializerProto"] = None,
) -> zmqtt.Message:
    """Build a fake ``zmqtt.Message`` carrying *message* on *topic*.

    The body is encoded with ``encode_message`` (using *serializer* if given).

    * ``version="3.1.1"``: only topic, payload, QoS and retain are set.
      *reply_to*, *correlation_id*, *headers* and the encoded content type are
      not carried, because MQTT 3.1.1 PUBLISH packets have no properties.
    * ``version="5.0"`` (default): metadata goes into ``PublishProperties``:
      content type, response topic (*reply_to*), correlation data
      (*correlation_id* encoded to UTF-8 bytes) and user properties
      (*headers* items). This lets ``MQTTParserV5`` read it back.

    Empty strings for content type, *reply_to* and *correlation_id* are
    mapped to ``None``.
    """
    payload, content_type = encode_message(message, serializer=serializer)

    if version == "3.1.1":
        return zmqtt.Message(
            topic=topic,
            payload=payload,
            qos=zmqtt.QoS(qos),
            retain=retain,
        )

    header_pairs = tuple((headers or {}).items())
    v5_properties = zmqtt.PublishProperties(
        content_type=content_type if content_type else None,
        response_topic=reply_to if reply_to else None,
        correlation_data=None if not correlation_id else correlation_id.encode(),
        user_properties=header_pairs,
    )

    return zmqtt.Message(
        topic=topic,
        payload=payload,
        qos=zmqtt.QoS(qos),
        retain=retain,
        properties=v5_properties,
    )
261 )