Coverage for faststream / redis / publisher / producer.py: 93%
47 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1from contextlib import suppress
2from typing import TYPE_CHECKING, Any, Optional, cast
4import anyio
5from typing_extensions import override
7from faststream._internal.endpoint.utils import ParserComposition
8from faststream._internal.producer import ProducerProto
9from faststream._internal.utils.nuid import NUID
10from faststream.redis.message import DATA_KEY
11from faststream.redis.parser import RedisPubSubParser, SimpleParserConfig
12from faststream.redis.response import DestinationType, RedisPublishCommand
14if TYPE_CHECKING:
15 from fast_depends.library.serializer import SerializerProto
17 from faststream._internal.types import CustomCallable
18 from faststream.redis.configs import ConnectionState
19 from faststream.redis.parser import MessageFormat
class RedisFastProducer(ProducerProto[RedisPublishCommand]):
    """A class to represent a Redis producer.

    Encodes outgoing commands with the command's ``message_format`` and sends
    them to a Redis channel (PUBLISH), list (RPUSH) or stream (XADD), using
    the command's pipeline when it has one and the shared client otherwise.
    """

    _decoder: "ParserComposition"
    _parser: "ParserComposition"

    def __init__(
        self,
        connection: "ConnectionState",
        parser: Optional["CustomCallable"],
        decoder: Optional["CustomCallable"],
        message_format: type["MessageFormat"],
        serializer: Optional["SerializerProto"],
    ) -> None:
        """Store the connection state and build the parser/decoder chain.

        Args:
            connection: Connection state whose ``client`` is used for commands
                that carry no pipeline.
            parser: Optional user parser, wrapped by ``ParserComposition``
                together with the default pub/sub parser.
            decoder: Optional user decoder, wrapped by ``ParserComposition``
                together with the default pub/sub decoder.
            message_format: Format class handed to the default
                ``RedisPubSubParser``.
            serializer: Serializer passed to ``message_format.encode``.
        """
        self._connection = connection

        default = RedisPubSubParser(SimpleParserConfig(message_format))
        self._parser = ParserComposition(
            parser,
            default.parse_message,
        )
        self._decoder = ParserComposition(
            decoder,
            default.decode_message,
        )
        self.serializer = serializer

    @override
    async def publish(self, cmd: "RedisPublishCommand") -> int | bytes:
        """Encode ``cmd`` and send it to its destination.

        Returns:
            The result of the underlying Redis command chosen by
            ``cmd.destination_type`` (PUBLISH / RPUSH reply, or the XADD id).
        """
        msg = cmd.message_format.encode(
            message=cmd.body,
            reply_to=cmd.reply_to,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id or "",
            serializer=self.serializer,
        )

        return await self.__publish(msg, cmd)

    @override
    async def request(self, cmd: "RedisPublishCommand") -> "Any":
        """Publish ``cmd`` and wait for one reply on a temporary channel.

        A unique reply channel is generated and subscribed to before the
        message is published; its name is sent as the message's ``reply_to``.

        Returns:
            The message object returned by ``psub.get_message`` for the reply.

        Raises:
            TimeoutError: If no reply arrives within ``cmd.timeout`` seconds.
                A ``None`` timeout waits indefinitely.
        """
        nuid = NUID()
        reply_to = str(nuid.next(), "utf-8")
        psub = self._connection.client.pubsub()

        # ``None`` means "no deadline" for anyio.fail_after and "block until a
        # message arrives" for redis-py's get_message. Coercing it to 0.0 (as
        # before) made get_message non-blocking, so a request without a
        # timeout returned ``None`` immediately instead of the reply.
        timeout = cmd.timeout

        try:
            await psub.subscribe(reply_to)

            msg = cmd.message_format.encode(
                message=cmd.body,
                reply_to=reply_to,
                headers=cmd.headers,
                correlation_id=cmd.correlation_id or "",
                serializer=self.serializer,
            )

            await self.__publish(msg, cmd)

            with anyio.fail_after(timeout) as scope:
                # skip subscribe message: the confirmation of our SUBSCRIBE is
                # read here and discarded via ignore_subscribe_messages=True
                await psub.get_message(
                    ignore_subscribe_messages=True,
                    timeout=timeout,
                )

                # get real response
                response_msg = await psub.get_message(
                    ignore_subscribe_messages=True,
                    timeout=timeout,
                )

            # Defensive: treat a deadline that fired inside the scope as a
            # timeout even if no exception propagated out of it.
            if scope.cancel_called:
                raise TimeoutError

            return response_msg

        finally:
            # Best-effort cleanup, done step by step so that a failing
            # UNSUBSCRIBE (e.g. on a broken connection) can no longer skip
            # aclose() and leak the pub/sub connection.
            with suppress(Exception):
                await psub.unsubscribe()
            with suppress(Exception):
                await psub.aclose()  # type: ignore[attr-defined]

    @override
    async def publish_batch(self, cmd: "RedisPublishCommand") -> int:
        """Encode every body in ``cmd.batch_bodies`` and RPUSH them in one call.

        All messages share the command's correlation id, reply-to and headers.
        The command's pipeline is used when present, otherwise the client.

        Returns:
            The reply of the single RPUSH call.
        """
        batch = [
            cmd.message_format.encode(
                message=msg,
                correlation_id=cmd.correlation_id or "",
                reply_to=cmd.reply_to,
                headers=cmd.headers,
                serializer=self.serializer,
            )
            for msg in cmd.batch_bodies
        ]

        connection = cmd.pipeline or self._connection.client
        return await connection.rpush(cmd.destination, *batch)

    async def __publish(
        self,
        msg: bytes,
        cmd: "RedisPublishCommand",
    ) -> int | bytes:
        """Send an already-encoded ``msg`` according to ``cmd.destination_type``.

        Channel -> PUBLISH, List -> RPUSH, Stream -> XADD (payload stored under
        ``DATA_KEY``, trimmed with ``cmd.maxlen``).

        Raises:
            AssertionError: If the destination type matches none of the above.
        """
        connection = cmd.pipeline or self._connection.client

        if cmd.destination_type is DestinationType.Channel:
            return await connection.publish(cmd.destination, msg)

        if cmd.destination_type is DestinationType.List:
            return await connection.rpush(cmd.destination, msg)

        if cmd.destination_type is DestinationType.Stream:
            return cast(
                "bytes",
                await connection.xadd(
                    name=cmd.destination,
                    fields={DATA_KEY: msg},
                    maxlen=cmd.maxlen,
                ),
            )

        error_msg = "unreachable"
        raise AssertionError(error_msg)

    def connect(self, serializer: Optional["SerializerProto"] = None) -> None:
        """Replace the serializer used to encode outgoing messages.

        Note: calling this without an argument resets the serializer to
        ``None``.
        """
        self.serializer = serializer