Coverage for faststream / redis / publisher / producer.py: 93%

47 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1from contextlib import suppress 

2from typing import TYPE_CHECKING, Any, Optional, cast 

3 

4import anyio 

5from typing_extensions import override 

6 

7from faststream._internal.endpoint.utils import ParserComposition 

8from faststream._internal.producer import ProducerProto 

9from faststream._internal.utils.nuid import NUID 

10from faststream.redis.message import DATA_KEY 

11from faststream.redis.parser import RedisPubSubParser, SimpleParserConfig 

12from faststream.redis.response import DestinationType, RedisPublishCommand 

13 

14if TYPE_CHECKING: 

15 from fast_depends.library.serializer import SerializerProto 

16 

17 from faststream._internal.types import CustomCallable 

18 from faststream.redis.configs import ConnectionState 

19 from faststream.redis.parser import MessageFormat 

20 

21 

class RedisFastProducer(ProducerProto[RedisPublishCommand]):
    """A class to represent a Redis producer.

    Encodes outgoing bodies with the command's ``message_format`` and writes
    them to a Redis channel, list or stream depending on
    ``cmd.destination_type``. Commands are issued on ``cmd.pipeline`` when one
    is set, otherwise on the connection state's ``client``.
    """

    _decoder: "ParserComposition"
    _parser: "ParserComposition"

    def __init__(
        self,
        connection: "ConnectionState",
        parser: Optional["CustomCallable"],
        decoder: Optional["CustomCallable"],
        message_format: type["MessageFormat"],
        serializer: Optional["SerializerProto"],
    ) -> None:
        """Create the producer.

        Args:
            connection: connection state; its ``client`` is used whenever a
                command carries no pipeline.
            parser: optional user parser, composed with the default
                ``RedisPubSubParser.parse_message``.
            decoder: optional user decoder, composed with the default
                ``RedisPubSubParser.decode_message``.
            message_format: message format used to build the default parser.
            serializer: serializer handed to ``message_format.encode``.
        """
        self._connection = connection

        default = RedisPubSubParser(SimpleParserConfig(message_format))
        self._parser = ParserComposition(
            parser,
            default.parse_message,
        )
        self._decoder = ParserComposition(
            decoder,
            default.decode_message,
        )
        self.serializer = serializer

    def _encode_message(
        self,
        cmd: "RedisPublishCommand",
        body: Any,
        reply_to: str,
    ) -> bytes:
        """Encode one body using the command's format, headers and correlation id."""
        return cmd.message_format.encode(
            message=body,
            reply_to=reply_to,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id or "",
            serializer=self.serializer,
        )

    @override
    async def publish(self, cmd: "RedisPublishCommand") -> int | bytes:
        """Encode ``cmd.body`` and send it to ``cmd.destination``.

        Returns:
            The raw result of the underlying Redis command (see ``__publish``).
        """
        msg = self._encode_message(cmd, cmd.body, cmd.reply_to)
        return await self.__publish(msg, cmd)

    @override
    async def request(self, cmd: "RedisPublishCommand") -> "Any":
        """Publish ``cmd`` and wait for one reply on a temporary channel.

        A unique reply channel is subscribed *before* the request is published,
        so a fast reply cannot be missed. ``cmd.timeout`` of ``None`` means wait
        without a deadline.

        Returns:
            The first non-control message received on the reply channel, as
            returned by the pubsub ``get_message`` call.

        Raises:
            TimeoutError: if no reply arrives within ``cmd.timeout`` seconds.
        """
        reply_to = str(NUID().next(), "utf-8")
        psub = self._connection.client.pubsub()

        try:
            await psub.subscribe(reply_to)

            msg = self._encode_message(cmd, cmd.body, reply_to)
            await self.__publish(msg, cmd)

            with anyio.fail_after(cmd.timeout) as scope:
                response_msg = None
                # With ignore_subscribe_messages=True, get_message() returns
                # None for the subscribe confirmation (and other control
                # replies), so keep reading until a real message shows up; the
                # fail_after deadline bounds this loop. cmd.timeout is passed
                # through as-is so that None blocks instead of turning into a
                # non-blocking 0.0-second poll.
                while response_msg is None:
                    response_msg = await psub.get_message(
                        ignore_subscribe_messages=True,
                        timeout=cmd.timeout,
                    )

            # Defensive: the deadline was hit even though the body completed.
            if scope.cancel_called:
                raise TimeoutError

            return response_msg

        finally:
            # Best-effort cleanup. The calls are suppressed separately so that a
            # failing unsubscribe() cannot skip aclose() and leak the connection.
            with suppress(Exception):
                await psub.unsubscribe()
            with suppress(Exception):
                await psub.aclose()  # type: ignore[attr-defined]

    @override
    async def publish_batch(self, cmd: "RedisPublishCommand") -> int:
        """Encode every body in ``cmd.batch_bodies`` and RPUSH them in one call.

        Returns:
            The result of the single ``rpush`` call on ``cmd.destination``.
        """
        batch = [
            self._encode_message(cmd, body, cmd.reply_to)
            for body in cmd.batch_bodies
        ]

        connection = cmd.pipeline or self._connection.client
        return await connection.rpush(cmd.destination, *batch)

    async def __publish(
        self,
        msg: bytes,
        cmd: "RedisPublishCommand",
    ) -> int | bytes:
        """Send an already-encoded message according to ``cmd.destination_type``.

        Channel -> ``publish``, List -> ``rpush``, Stream -> ``xadd`` with the
        payload stored under ``DATA_KEY`` and ``cmd.maxlen`` applied.

        Returns:
            The result of ``publish``/``rpush``, or the ``xadd`` result cast to
            ``bytes``.

        Raises:
            AssertionError: if ``cmd.destination_type`` is none of the above.
        """
        connection = cmd.pipeline or self._connection.client

        if cmd.destination_type is DestinationType.Channel:
            return await connection.publish(cmd.destination, msg)

        if cmd.destination_type is DestinationType.List:
            return await connection.rpush(cmd.destination, msg)

        if cmd.destination_type is DestinationType.Stream:
            return cast(
                "bytes",
                await connection.xadd(
                    name=cmd.destination,
                    fields={DATA_KEY: msg},
                    maxlen=cmd.maxlen,
                ),
            )

        error_msg = "unreachable"
        raise AssertionError(error_msg)

    def connect(self, serializer: Optional["SerializerProto"] = None) -> None:
        """Replace the serializer used when encoding outgoing messages."""
        self.serializer = serializer