Coverage for faststream / redis / testing.py: 98%
135 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import re
2from collections.abc import Iterator, Sequence
3from contextlib import ExitStack, contextmanager
4from typing import (
5 TYPE_CHECKING,
6 Any,
7 Optional,
8 Protocol,
9 Union,
10 cast,
11)
12from unittest.mock import AsyncMock, MagicMock
14import anyio
15from typing_extensions import TypedDict, override
17from faststream._internal.endpoint.utils import ParserComposition
18from faststream._internal.testing.broker import TestBroker, change_producer
19from faststream.exceptions import SetupError, SubscriberNotFound
20from faststream.message import gen_cor_id
21from faststream.redis.broker.broker import RedisBroker
22from faststream.redis.message import (
23 BatchListMessage,
24 BatchStreamMessage,
25 DefaultListMessage,
26 DefaultStreamMessage,
27 PubSubMessage,
28 bDATA_KEY,
29)
30from faststream.redis.parser import MessageFormat, ParserConfig, RedisPubSubParser
31from faststream.redis.publisher.producer import RedisFastProducer
32from faststream.redis.response import DestinationType, RedisPublishCommand
33from faststream.redis.schemas import INCORRECT_SETUP_MSG
34from faststream.redis.subscriber.usecases.channel_subscriber import ChannelSubscriber
35from faststream.redis.subscriber.usecases.list_subscriber import _ListHandlerMixin
36from faststream.redis.subscriber.usecases.stream_subscriber import _StreamHandlerMixin
38if TYPE_CHECKING:
39 from fast_depends.library.serializer import SerializerProto
41 from faststream._internal.basic_types import SendableMessage
42 from faststream.redis.publisher.usecase import LogicPublisher
43 from faststream.redis.subscriber.usecases.basic import LogicSubscriber
# Public API of this module: only the test client class is exported; the
# fake producer, visitors and helpers below are internal.
__all__ = ("TestRedisBroker",)
class TestRedisBroker(TestBroker[RedisBroker]):
    """Test client for ``RedisBroker`` that replaces Redis I/O with in-process fakes."""

    @contextmanager
    def _patch_producer(self, broker: RedisBroker) -> Iterator[None]:
        """Swap the broker's producer and every publisher's producer for fakes.

        Each swap is entered via ``change_producer`` on one ``ExitStack``, so
        all of them are unwound together when this context exits.
        """
        with ExitStack() as patches:
            broker_fake = FakeProducer(broker, broker.config)
            patches.enter_context(
                change_producer(broker.config.broker_config, broker_fake),
            )

            publishers = cast("list[LogicPublisher]", broker.publishers)
            for pub in publishers:
                # Each publisher gets its own fake built from its own parser config.
                patches.enter_context(
                    change_producer(pub, FakeProducer(broker, pub.config)),
                )

            yield

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: RedisBroker,
        publisher: "LogicPublisher",
    ) -> tuple["LogicSubscriber", bool]:
        """Find (or create) a subscriber that consumes what ``publisher`` sends.

        Returns:
            ``(subscriber, True)`` when an already-registered subscriber matches
            the publisher's destination (the last match in
            ``broker.subscribers`` order is returned), otherwise
            ``(subscriber, False)`` for a freshly registered non-persistent
            subscriber built from the publisher's full subscriber options.
        """
        named_property = publisher.subscriber_property(name_only=True)
        visitors = (ChannelVisitor(), ListVisitor(), StreamVisitor())

        matching = [
            candidate
            for candidate in cast("list[LogicSubscriber]", broker.subscribers)
            if any(v.visit(**named_property, sub=candidate) for v in visitors)
        ]

        if matching:
            return matching[-1], True

        sub_options = publisher.subscriber_property(name_only=False)
        fake_sub = broker.subscriber(**sub_options, persistent=False)
        return fake_sub, False

    @staticmethod
    async def _fake_connect(  # type: ignore[override]
        broker: RedisBroker,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncMock:
        """Install a mocked Redis client on ``broker`` and return it.

        The mock's ``pubsub()`` yields an ``AsyncMock`` whose ``get_message``
        only sleeps for the requested ``timeout``; ``aclose``, ``xack`` and
        ``xdel`` are ``AsyncMock`` instances.
        """
        pub_sub = AsyncMock()

        async def _idle_get_message(*_args: Any, timeout: float, **_kwargs: Any) -> None:
            # No message ever arrives over the fake connection: wait out the poll.
            await anyio.sleep(timeout)

        pub_sub.get_message = _idle_get_message

        client = MagicMock()
        client.pubsub.side_effect = lambda: pub_sub
        client.aclose = AsyncMock()
        client.xack = AsyncMock()
        client.xdel = AsyncMock()

        broker.config.broker_config.connection._client = client
        return client
class FakeProducer(RedisFastProducer):
    """Producer that delivers messages in-process to the broker's subscribers.

    Nothing is sent over the network: every command is encoded with
    ``build_message`` and handed directly to each subscriber whose
    channel/list/stream destination matches.
    """

    def __init__(self, broker: RedisBroker, config: ParserConfig) -> None:
        # NOTE(review): RedisFastProducer.__init__ is not called, presumably
        # because the fake needs no real connection.
        self.broker = broker

        default = RedisPubSubParser(config)

        # The broker's custom parser/decoder (if any) is composed with the
        # default Pub/Sub parser/decoder.
        self._parser = ParserComposition(
            broker._parser,
            default.parse_message,
        )
        self._decoder = ParserComposition(
            broker._decoder,
            default.decode_message,
        )

    @override
    async def publish(self, cmd: "RedisPublishCommand") -> int | bytes:
        """Deliver ``cmd`` to every matching subscriber.

        Returns:
            Always ``0``.

        Raises:
            SetupError: ``cmd`` does not target a channel, list or stream.
        """
        body = build_message(
            message=cmd.body,
            reply_to=cmd.reply_to,
            correlation_id=cmd.correlation_id or gen_cor_id(),
            headers=cmd.headers,
            message_format=cmd.message_format,
            serializer=self.broker.config.fd_config._serializer,
        )

        for msg, handler in self._iter_deliveries(cmd, body):
            await self._execute_handler(msg, handler)

        return 0

    @override
    async def request(self, cmd: "RedisPublishCommand") -> "PubSubMessage":
        """Deliver ``cmd`` to the first matching subscriber and return its reply.

        Raises:
            SubscriberNotFound: no subscriber matches the destination.
            TimeoutError: the handler exceeds ``cmd.timeout`` (``anyio.fail_after``).
            SetupError: ``cmd`` does not target a channel, list or stream.
        """
        body = build_message(
            message=cmd.body,
            correlation_id=cmd.correlation_id or gen_cor_id(),
            headers=cmd.headers,
            message_format=cmd.message_format,
            serializer=self.broker.config.fd_config._serializer,
        )

        for msg, handler in self._iter_deliveries(cmd, body):
            with anyio.fail_after(cmd.timeout):
                return await self._execute_handler(msg, handler)

        raise SubscriberNotFound

    @override
    async def publish_batch(self, cmd: "RedisPublishCommand") -> int:
        """Deliver a batch published to the list ``cmd.destination``.

        Batch list subscribers receive all messages in a single call.
        Non-batch list subscribers receive each message separately instead of
        being skipped; a non-batch consumer of a real Redis list would pop the
        pushed items one at a time.

        Returns:
            Always ``0``.
        """
        # `cmd.correlation_id or gen_cor_id()` is evaluated per message, so
        # without an explicit id every message gets its own generated one.
        data_to_send = [
            build_message(
                m,
                correlation_id=cmd.correlation_id or gen_cor_id(),
                headers=cmd.headers,
                message_format=cmd.message_format,
                serializer=self.broker.config.fd_config._serializer,
            )
            for m in cmd.batch_bodies
        ]

        visitor = ListVisitor()
        for handler in self.broker.subscribers:  # pragma: no branch
            handler = cast("LogicSubscriber", handler)
            if not visitor.visit(list=cmd.destination, sub=handler):
                continue

            list_handler = cast("_ListHandlerMixin", handler)

            if list_handler.list_sub.batch:
                msg = visitor.get_message(
                    channel=cmd.destination,
                    body=data_to_send,
                    sub=list_handler,
                )
                await self._execute_handler(msg, handler)
            else:
                for item in data_to_send:
                    msg = visitor.get_message(
                        channel=cmd.destination,
                        body=item,
                        sub=list_handler,
                    )
                    await self._execute_handler(msg, handler)

        return 0

    def _iter_deliveries(
        self,
        cmd: "RedisPublishCommand",
        body: bytes,
    ) -> Iterator[tuple[Any, "LogicSubscriber"]]:
        """Yield ``(raw_message, subscriber)`` for each subscriber matching ``cmd``.

        Shared by ``publish`` (delivers to all) and ``request`` (stops at the
        first). ``_make_destination_kwargs`` runs on the first iteration and
        raises ``SetupError`` for an invalid destination.
        """
        destination = _make_destination_kwargs(cmd)
        visitors = (ChannelVisitor(), ListVisitor(), StreamVisitor())

        for handler in self.broker.subscribers:  # pragma: no branch
            handler = cast("LogicSubscriber", handler)
            for visitor in visitors:
                if visited_ch := visitor.visit(**destination, sub=handler):
                    msg = visitor.get_message(
                        visited_ch,
                        body,
                        handler,  # type: ignore[arg-type]
                    )
                    yield msg, handler

    async def _execute_handler(
        self,
        msg: Any,
        handler: "LogicSubscriber",
    ) -> "PubSubMessage":
        """Run ``handler`` on ``msg`` and wrap its result as a Pub/Sub reply.

        The result is re-encoded with the handler's own message format; a
        missing correlation id is replaced by an empty string.
        """
        result = await handler.process_message(msg)

        return PubSubMessage(
            type="message",
            data=build_message(
                message=result.body,
                headers=result.headers,
                correlation_id=result.correlation_id or "",
                message_format=handler.config.message_format,
                serializer=self.broker.config.fd_config._serializer,
            ),
            channel="",
            pattern=None,
        )
240def build_message(
241 message: Union[Sequence["SendableMessage"], "SendableMessage"],
242 *,
243 correlation_id: str,
244 message_format: type["MessageFormat"],
245 reply_to: str = "",
246 headers: dict[str, Any] | None = None,
247 serializer: Optional["SerializerProto"] = None,
248) -> bytes:
249 return message_format.encode(
250 message=message,
251 reply_to=reply_to,
252 headers=headers,
253 correlation_id=correlation_id,
254 serializer=serializer,
255 )
258class Visitor(Protocol):
259 def visit(
260 self,
261 *,
262 channel: str | None,
263 list: str | None,
264 stream: str | None,
265 sub: "LogicSubscriber",
266 ) -> str | None: ...
268 def get_message(self, channel: str, body: Any, sub: "LogicSubscriber") -> Any: ...
class ChannelVisitor(Visitor):
    """Match Pub/Sub channel destinations against channel subscribers."""

    def visit(
        self,
        *,
        sub: "LogicSubscriber",
        channel: str | None = None,
        list: str | None = None,
        stream: str | None = None,
    ) -> str | None:
        """Return ``channel`` if ``sub`` would receive messages published to it.

        A subscriber matches when its channel name equals ``channel``. A
        pattern subscriber also matches when ``channel`` matches its
        ``*``-wildcard name in full. Returns ``None`` otherwise, including
        when ``channel`` is ``None`` or ``sub`` is not a ``ChannelSubscriber``.
        """
        if channel is None or not isinstance(sub, ChannelSubscriber):
            return None

        sub_channel = sub.channel

        if channel == sub_channel.name:
            return channel

        if sub_channel.pattern and self._pattern_matches(sub_channel.name, channel):
            return channel

        return None

    @staticmethod
    def _pattern_matches(name: str, channel: str) -> bool:
        """Return whether ``channel`` matches the wildcard ``name`` end to end."""
        # `.` is literal and `*` stands for any run of characters. fullmatch
        # is required: re.match anchors only the start, so "*.bar" would also
        # accept "x.barbaz", which a Redis pattern subscription does not.
        regex = name.replace(".", "\\.").replace("*", ".*")
        return re.fullmatch(regex, channel) is not None

    def get_message(  # type: ignore[override]
        self,
        channel: str,
        body: Any,
        sub: "ChannelSubscriber",
    ) -> Any:
        """Wrap ``body`` as a Pub/Sub message; ``pattern`` is set only for pattern subscribers."""
        return PubSubMessage(
            type="message",
            data=body,
            channel=channel,
            pattern=sub.channel.pattern.encode() if sub.channel.pattern else None,
        )
class ListVisitor(Visitor):
    """Match list destinations against list subscribers."""

    def visit(
        self,
        *,
        sub: "LogicSubscriber",
        channel: str | None = None,
        list: str | None = None,
        stream: str | None = None,
    ) -> str | None:
        """Return ``list`` when ``sub`` is a list subscriber reading exactly that list."""
        is_match = (
            list is not None
            and isinstance(sub, _ListHandlerMixin)
            and list == sub.list_sub.name
        )
        return list if is_match else None

    def get_message(  # type: ignore[override]
        self,
        channel: str,
        body: Any,
        sub: "_ListHandlerMixin",
    ) -> Any:
        """Wrap ``body`` as a single or batch list message, depending on ``sub``."""
        if not sub.list_sub.batch:
            return DefaultListMessage(
                type="list",
                channel=channel,
                data=body,
            )

        # Batch subscribers always receive a list, even for a single body.
        items = body if isinstance(body, list) else [body]
        return BatchListMessage(
            type="blist",
            channel=channel,
            data=items,
        )
class StreamVisitor(Visitor):
    """Match stream destinations against stream subscribers."""

    def visit(
        self,
        *,
        sub: "LogicSubscriber",
        channel: str | None = None,
        list: str | None = None,
        stream: str | None = None,
    ) -> str | None:
        """Return ``stream`` when ``sub`` is a stream subscriber reading exactly that stream."""
        if stream is not None and isinstance(sub, _StreamHandlerMixin):
            if stream == sub.stream_sub.name:
                return stream
        return None

    def get_message(  # type: ignore[override]
        self,
        channel: str,
        body: Any,
        sub: "_StreamHandlerMixin",
    ) -> Any:
        """Wrap ``body`` under ``bDATA_KEY`` as a single or batch stream entry, with no message ids."""
        entry = {bDATA_KEY: body}

        if not sub.stream_sub.batch:
            return DefaultStreamMessage(
                type="stream",
                channel=channel,
                data=entry,
                message_ids=[],
            )

        return BatchStreamMessage(
            type="bstream",
            channel=channel,
            data=[entry],
            message_ids=[],
        )
388class _DestinationKwargs(TypedDict, total=False):
389 channel: str
390 list: str
391 stream: str
def _make_destination_kwargs(cmd: RedisPublishCommand) -> _DestinationKwargs:
    """Translate ``cmd``'s destination into the single keyword the visitors expect.

    Raises:
        SetupError: ``cmd.destination_type`` is not exactly one of the channel,
            list or stream destination types.
    """
    keyword_by_type = (
        ("channel", DestinationType.Channel),
        ("list", DestinationType.List),
        ("stream", DestinationType.Stream),
    )
    destination = cast(
        "_DestinationKwargs",
        {
            keyword: cmd.destination
            for keyword, dest_type in keyword_by_type
            if cmd.destination_type is dest_type
        },
    )

    if len(destination) == 1:
        return destination

    raise SetupError(INCORRECT_SETUP_MSG)