Coverage for faststream / mqtt / subscriber / usecase.py: 89%

86 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import warnings 

2from abc import abstractmethod 

3from collections.abc import AsyncIterator, Sequence 

4from contextlib import suppress 

5from typing import TYPE_CHECKING, Any 

6 

7import anyio 

8import zmqtt 

9from typing_extensions import override 

10 

11from faststream._internal.endpoint.subscriber import SubscriberUsecase 

12from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin, TasksMixin 

13from faststream._internal.endpoint.utils import process_msg 

14from faststream.middlewares import AckPolicy 

15from faststream.mqtt.parser import MQTTBaseParser, MQTTParserV5, MQTTParserV311 

16from faststream.mqtt.publisher.fake import MQTTFakePublisher 

17 

18if TYPE_CHECKING: 

19 from faststream._internal.endpoint.publisher import PublisherProto 

20 from faststream._internal.endpoint.subscriber import SubscriberSpecification 

21 from faststream._internal.endpoint.subscriber.call_item import CallsCollection 

22 from faststream.message import StreamMessage 

23 from faststream.mqtt.broker.config import MQTTBrokerConfig 

24 from faststream.mqtt.message import MQTTMessage 

25 from faststream.mqtt.subscriber.config import MQTTSubscriberConfig 

26 

27 

class MQTTBaseSubscriber(TasksMixin, SubscriberUsecase[zmqtt.Message]):
    """Base class for all MQTT subscribers.

    Holds at most one ``zmqtt.Subscription`` for the configured topic
    (optionally wrapped in a ``$share/<group>/`` shared-subscription prefix)
    and picks the MQTT 3.1.1 or 5.0 parser/decoder from the outer config's
    ``version``. Subclasses implement :meth:`_consume_loop` to decide how
    messages pulled from the subscription are dispatched.
    """

    _outer_config: "MQTTBrokerConfig"

    def __init__(
        self,
        config: "MQTTSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[zmqtt.Message]",
    ) -> None:
        # The version may not be available yet when the subscriber is created
        # on a router before include_router is called; the helper defaults to
        # V5 in that case and start() re-resolves the parser.
        parser = self._select_mqtt_parser(config._outer_config)
        config.parser = parser.parse_message
        config.decoder = parser.decode_message
        super().__init__(config, specification, calls)
        self._topic = config.topic
        self._shared = config.shared
        self._qos = config.qos
        self._subscription: zmqtt.Subscription | None = None

        if config.ack_policy is AckPolicy.NACK_ON_ERROR:
            # Kept directly in __init__ so stacklevel=3 keeps pointing at the
            # same caller frame as before.
            warnings.warn(
                "MQTT has no nack primitive; with NACK_ON_ERROR, "
                "on error QoS 1/2 messages will not be acknowledged "
                "and the broker will redeliver them.",
                RuntimeWarning,
                stacklevel=3,
            )

    @staticmethod
    def _select_mqtt_parser(outer_config: Any) -> MQTTBaseParser:
        """Return a parser matching the MQTT protocol version of ``outer_config``.

        ``outer_config`` may not expose ``version`` yet (e.g. a router that has
        not been included into a broker); in that case MQTT 5.0 is assumed.
        Only the exact value ``"3.1.1"`` selects the 3.1.1 parser.
        """
        if getattr(outer_config, "version", "5.0") == "3.1.1":
            return MQTTParserV311()
        return MQTTParserV5()

    @property
    def topic(self) -> str:
        """Full topic filter: outer prefix + topic, shared-wrapped if ``shared`` is set."""
        full = f"{self._outer_config.prefix}{self._topic}"
        return f"$share/{self._shared}/{full}" if self._shared else full

    def _make_response_publisher(
        self,
        message: "StreamMessage[Any]",
    ) -> Sequence["PublisherProto"]:
        """Return a one-element tuple with a publisher targeting ``message.reply_to``."""
        return (
            MQTTFakePublisher(
                producer=self._outer_config.producer,
                topic=message.reply_to,
            ),
        )

    @staticmethod
    def build_log_context(
        message: "StreamMessage[zmqtt.Message] | None",
        topic: str = "",
    ) -> dict[str, str]:
        """Build the logging context for ``message``.

        ``message_id`` falls back to ``""`` when ``message`` has no such
        attribute (for example when ``message`` is ``None``).
        """
        return {
            "topic": topic,
            "message_id": getattr(message, "message_id", ""),
        }

    def get_log_context(
        self,
        message: "StreamMessage[zmqtt.Message] | None",
    ) -> dict[str, str]:
        """Logging context for ``message`` using this subscriber's full topic."""
        return self.build_log_context(message=message, topic=self.topic)

    @override
    async def start(self) -> None:
        """Start the subscriber; subscribe and spawn the consume loop if handlers exist."""
        # Re-resolve the parser now that _outer_config is fully composed
        # (i.e. include_router has been called and the broker's MQTTBrokerConfig
        # is reachable through the config chain).
        parser = self._select_mqtt_parser(self._outer_config)
        self._parser = parser.parse_message
        self._decoder = parser.decode_message

        await super().start()

        if self.calls:
            await self._create_subscription()
            self.add_task(self._consume_loop)

        self._post_start()

    @override
    async def stop(self) -> None:
        """Stop background tasks, then tear down the subscription (best-effort)."""
        await super().stop()
        if self._subscription is not None:
            # Deliberate best-effort: errors from stopping the subscription are
            # suppressed so the subscriber always ends up detached.
            with suppress(Exception):
                await self._subscription.stop()
            self._subscription = None

    async def _create_subscription(self) -> None:
        """Subscribe to :attr:`topic` with the configured QoS and start it.

        ``auto_ack`` is enabled only for ``AckPolicy.ACK_FIRST``. If
        ``start()`` raises, the half-created subscription is stopped
        (best-effort) and ``self._subscription`` is reset to ``None`` before
        re-raising. This lets a later call retry instead of reusing a
        subscription that never started.
        """
        auto_ack = self.ack_policy is AckPolicy.ACK_FIRST
        subscription = self._outer_config.client.subscribe(
            self.topic,
            qos=zmqtt.QoS(self._qos),
            auto_ack=auto_ack,
        )
        self._subscription = subscription
        try:
            await subscription.start()
        except Exception:
            self._subscription = None
            with suppress(Exception):
                await subscription.stop()
            raise

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
    ) -> "StreamMessage[zmqtt.Message] | None":
        """Pull a single message from the topic.

        The subscription is created lazily on first use. ``raw_msg`` stays
        ``None`` when nothing arrives within ``timeout`` seconds; presumably
        ``process_msg`` then yields ``None`` (matching the return type).
        Only valid for a subscriber without registered handlers.
        """
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        if self._subscription is None:
            await self._create_subscription()
        assert self._subscription is not None

        async_parser, async_decoder = self._get_parser_and_decoder()

        raw_msg: zmqtt.Message | None = None
        with anyio.move_on_after(timeout):
            raw_msg = await self._subscription.get_message()

        context = self._outer_config.fd_config.context
        return await process_msg(
            msg=raw_msg,
            middlewares=(m(raw_msg, context=context) for m in self._broker_middlewares),
            parser=async_parser,
            decoder=async_decoder,
        )

    @override
    async def __aiter__(self) -> AsyncIterator["StreamMessage[zmqtt.Message]"]:  # type: ignore[override]
        """Iterate over parsed messages from the (lazily created) subscription."""
        if self._subscription is None:
            await self._create_subscription()

        assert self._subscription is not None
        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()
        async for raw_msg in self._subscription:
            msg: MQTTMessage = await process_msg(  # type: ignore[assignment]
                msg=raw_msg,
                middlewares=(
                    m(raw_msg, context=context) for m in self._broker_middlewares
                ),
                parser=async_parser,
                decoder=async_decoder,
            )
            yield msg

    @abstractmethod
    async def _consume_loop(self) -> None:
        """Pull messages from ``self._subscription`` and dispatch them (subclass hook)."""
        raise NotImplementedError

188 

189 

class MQTTDefaultSubscriber(MQTTBaseSubscriber):
    """MQTT subscriber that hands messages to ``consume`` strictly one at a time."""

    async def _consume_loop(self) -> None:
        # Each message is awaited to completion before the next one is pulled.
        subscription = self._subscription
        assert subscription is not None
        async for raw_message in subscription:
            await self.consume(raw_message)

197 

198 

class MQTTConcurrentSubscriber(ConcurrentMixin[zmqtt.Message], MQTTBaseSubscriber):
    """MQTT subscriber that processes up to ``max_workers`` messages in parallel.

    The consume loop only enqueues raw messages via ``_put_msg``; the actual
    handling is presumably performed by the worker task(s) that
    ``start_consume_task`` launches in ``ConcurrentMixin``.
    """

    @override
    async def start(self) -> None:
        # Base start first (parser resolution, subscription, consume loop),
        # then the concurrent worker task(s).
        await super().start()
        self.start_consume_task()

    async def _consume_loop(self) -> None:
        subscription = self._subscription
        assert subscription is not None
        async for raw_message in subscription:
            await self._put_msg(raw_message)