Coverage for faststream / rabbit / publisher / producer.py: 92%

65 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1from abc import abstractmethod 

2from typing import ( 

3 TYPE_CHECKING, 

4 Optional, 

5 Protocol, 

6 cast, 

7) 

8 

9import anyio 

10from typing_extensions import Unpack, override 

11 

12from faststream._internal.endpoint.utils import ParserComposition 

13from faststream._internal.producer import ProducerProto 

14from faststream.exceptions import FeatureNotSupportedException, IncorrectState 

15from faststream.rabbit.parser import AioPikaParser 

16from faststream.rabbit.response import RabbitPublishCommand 

17from faststream.rabbit.schemas import RABBIT_REPLY, RabbitExchange 

18 

19if TYPE_CHECKING: 

20 from types import TracebackType 

21 

22 import aiormq 

23 from aio_pika import IncomingMessage, RobustQueue 

24 from aio_pika.abc import AbstractIncomingMessage, TimeoutType 

25 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream 

26 from fast_depends.library.serializer import SerializerProto 

27 

28 from faststream._internal.types import ( 

29 AsyncCallable, 

30 CustomCallable, 

31 ) 

32 from faststream.rabbit.helpers import RabbitDeclarer 

33 from faststream.rabbit.types import AioPikaSendableMessage 

34 

35 from .options import MessageOptions 

36 

37 

class LockState(Protocol):
    """Structural type for objects that expose an ``anyio.Lock`` as ``lock``."""

    @property
    def lock(self) -> "anyio.Lock":
        """The lock held by this state object."""

41 

42 

class LockUnset:
    """Lock state of a producer whose ``connect()`` has not been called.

    Reading ``lock`` always raises ``IncorrectState``, telling the caller to
    connect the producer first.
    """

    __slots__ = ()

    @property
    def lock(self) -> "anyio.Lock":
        raise IncorrectState(
            "You should call `producer.connect()` method at first.",
        )

50 

51 

class RealLock:
    """Lock state that owns a freshly created ``anyio.Lock`` in ``lock``."""

    __slots__ = ("lock",)

    def __init__(self) -> None:
        fresh_lock = anyio.Lock()
        self.lock = fresh_lock

57 

58 

class AioPikaFastProducer(ProducerProto[RabbitPublishCommand]):
    """Common base for the aio-pika producers.

    ``connect`` and ``disconnect`` are no-op hooks here. Subclasses must
    implement ``publish`` and ``request``. ``publish_batch`` always raises
    ``FeatureNotSupportedException``.
    """

    def connect(self, serializer: Optional["SerializerProto"] = None) -> None:
        """Hook for connecting the producer; does nothing by default."""

    def disconnect(self) -> None:
        """Hook for disconnecting the producer; does nothing by default."""

    @abstractmethod
    async def publish(
        self,
        cmd: "RabbitPublishCommand",
    ) -> Optional["aiormq.abc.ConfirmationFrameType"]:
        """Publish ``cmd``; implemented by subclasses."""

    @abstractmethod
    async def request(self, cmd: "RabbitPublishCommand") -> "IncomingMessage":
        """Publish ``cmd`` and wait for its reply; implemented by subclasses."""

    @override
    async def publish_batch(self, cmd: "RabbitPublishCommand") -> None:
        """Always rejected: batch publishing is not supported for RabbitMQ."""
        raise FeatureNotSupportedException(
            "RabbitMQ doesn't support publishing in batches.",
        )

77 

78 

class FakeAioPikaFastProducer(AioPikaFastProducer):
    """Null-object producer.

    Instances evaluate as false. ``connect``, ``disconnect``, ``publish`` and
    ``request`` all raise ``NotImplementedError``.
    """

    @override
    async def publish(
        self,
        cmd: "RabbitPublishCommand",
    ) -> Optional["aiormq.abc.ConfirmationFrameType"]:
        raise NotImplementedError

    @override
    async def request(self, cmd: "RabbitPublishCommand") -> "IncomingMessage":
        raise NotImplementedError

    def connect(self, serializer: Optional["SerializerProto"] = None) -> None:
        raise NotImplementedError

    def disconnect(self) -> None:
        raise NotImplementedError

    def __bool__(self) -> bool:
        # Lets callers test ``if producer:`` to detect this placeholder.
        return False

99 

100 

class AioPikaFastProducerImpl(AioPikaFastProducer):
    """aio-pika based producer for publishing messages and RPC requests.

    Until ``connect()`` is called, the internal lock state is ``LockUnset``.
    In that state ``request()`` raises ``IncorrectState``.
    """

    _decoder: "AsyncCallable"
    _parser: "AsyncCallable"

    def __init__(
        self,
        *,
        declarer: "RabbitDeclarer",
        parser: Optional["CustomCallable"],
        decoder: Optional["CustomCallable"],
    ) -> None:
        self.declarer = declarer

        # Placeholder state; connect() swaps in a RealLock.
        self.__lock: LockState = LockUnset()
        self.serializer: SerializerProto | None = None

        # User-supplied parser/decoder are composed with the aio-pika defaults.
        fallback = AioPikaParser()
        self._parser = ParserComposition(parser, fallback.parse_message)
        self._decoder = ParserComposition(decoder, fallback.decode_message)

    def connect(self, serializer: Optional["SerializerProto"] = None) -> None:
        """Store the serializer and create the RPC lock.

        Call this from async context: the design assumes an ``anyio.Lock``
        cannot be created outside a running event loop.
        """
        self.serializer = serializer
        self.__lock = RealLock()

    def disconnect(self) -> None:
        """Drop the RPC lock; ``request()`` fails until ``connect()`` again."""
        self.__lock = LockUnset()

    @override
    async def publish(
        self,
        cmd: "RabbitPublishCommand",
    ) -> Optional["aiormq.abc.ConfirmationFrameType"]:
        """Publish ``cmd`` using its own ``reply_to``."""
        return await self._publish_command(cmd, reply_to=cmd.reply_to)

    @override
    async def request(self, cmd: "RabbitPublishCommand") -> "IncomingMessage":
        """Publish ``cmd`` with ``RABBIT_REPLY`` as reply-to and await one reply.

        Requests are serialized by the producer lock. Publishing plus waiting
        for the reply is bounded by ``cmd.timeout`` via ``anyio.fail_after``.
        """
        rpc_lock = self.__lock.lock
        reply_queue = await self.declarer.declare_queue(RABBIT_REPLY)

        async with _RPCCallback(rpc_lock, reply_queue) as responses:
            with anyio.fail_after(cmd.timeout):
                await self._publish_command(cmd, reply_to=RABBIT_REPLY.name)
                return await responses.receive()

    async def _publish_command(
        self,
        cmd: "RabbitPublishCommand",
        *,
        reply_to: str | None,
    ) -> Optional["aiormq.abc.ConfirmationFrameType"]:
        """Unpack ``cmd`` into a ``_publish`` call with the given reply-to."""
        return await self._publish(
            message=cmd.body,
            exchange=cmd.exchange,
            routing_key=cmd.destination,
            reply_to=reply_to,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            **cmd.publish_options,
            **cmd.message_options,
        )

    async def _publish(
        self,
        message: "AioPikaSendableMessage",
        *,
        exchange: "RabbitExchange",
        routing_key: str,
        mandatory: bool = True,
        immediate: bool = False,
        timeout: "TimeoutType" = None,
        **message_options: Unpack["MessageOptions"],
    ) -> Optional["aiormq.abc.ConfirmationFrameType"]:
        """Encode ``message`` and publish it to ``exchange`` (not redeclared)."""
        encoded = AioPikaParser.encode_message(
            message=message,
            serializer=self.serializer,
            **message_options,
        )

        target_exchange = await self.declarer.declare_exchange(
            exchange=exchange,
            declare=False,
        )

        return await target_exchange.publish(
            message=encoded,
            routing_key=routing_key,
            mandatory=mandatory,
            immediate=immediate,
            timeout=timeout,
        )

196 

197 

198class _RPCCallback: 

199 """A class provides an RPC lock.""" 

200 

201 def __init__(self, lock: "anyio.Lock", callback_queue: "RobustQueue") -> None: 

202 self.lock = lock 

203 self.queue = callback_queue 

204 

205 async def __aenter__(self) -> "MemoryObjectReceiveStream[IncomingMessage]": 

206 send_response_stream: MemoryObjectSendStream[AbstractIncomingMessage] 

207 receive_response_stream: MemoryObjectReceiveStream[AbstractIncomingMessage] 

208 

209 ( 

210 send_response_stream, 

211 receive_response_stream, 

212 ) = anyio.create_memory_object_stream(max_buffer_size=1) 

213 await self.lock.acquire() 

214 

215 self.consumer_tag = await self.queue.consume( 

216 callback=send_response_stream.send, 

217 no_ack=True, 

218 ) 

219 

220 return cast( 

221 "MemoryObjectReceiveStream[IncomingMessage]", 

222 receive_response_stream, 

223 ) 

224 

225 async def __aexit__( 

226 self, 

227 exc_type: type[BaseException] | None = None, 

228 exc_val: BaseException | None = None, 

229 exc_tb: Optional["TracebackType"] = None, 

230 ) -> None: 

231 self.lock.release() 

232 await self.queue.cancel(self.consumer_tag)