Coverage for faststream / nats / testing.py: 98%
92 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1from collections.abc import Generator, Iterable, Iterator
2from contextlib import ExitStack, contextmanager
3from typing import TYPE_CHECKING, Any, Optional, cast
4from unittest.mock import AsyncMock
6import anyio
7from nats.aio.msg import Msg
8from typing_extensions import override
10from faststream._internal.endpoint.utils import ParserComposition
11from faststream._internal.testing.broker import TestBroker
12from faststream.exceptions import SubscriberNotFound
13from faststream.message import encode_message, gen_cor_id
14from faststream.nats.broker import NatsBroker
15from faststream.nats.parser import NatsParser
16from faststream.nats.publisher.producer import NatsFastProducer
17from faststream.nats.schemas.js_stream import is_subject_match_wildcard
19if TYPE_CHECKING:
20 from fast_depends.library.serializer import SerializerProto
22 from faststream._internal.basic_types import SendableMessage
23 from faststream._internal.configs.broker import ConfigComposition
24 from faststream.nats.configs import NatsBrokerConfig
25 from faststream.nats.publisher.usecase import LogicPublisher
26 from faststream.nats.response import NatsPublishCommand
27 from faststream.nats.subscriber.usecases.basic import LogicSubscriber
# The in-memory test broker is the only name this module exports; the fake
# producer and message helpers below are implementation details.
__all__ = (
    "TestNatsBroker",
)
@contextmanager
def change_producer(
    config: "ConfigComposition[NatsBrokerConfig]",
    producer: "NatsFastProducer",
) -> Generator[None, None, None]:
    """Temporarily use ``producer`` for both core and JetStream publishing.

    While the ``with`` block runs, ``config.broker_config.producer`` and
    ``config.broker_config.js_producer`` both point at ``producer``. The
    previous values are restored on exit, including when the block raises.

    Args:
        config: Broker configuration whose producers are swapped.
        producer: Producer to install for the duration of the context.
    """
    old_producer = config.broker_config.producer
    old_js_producer = config.broker_config.js_producer
    try:
        config.broker_config.producer = producer
        config.broker_config.js_producer = producer
        yield
    finally:
        # Restore in a ``finally`` so a failing test body cannot leave the
        # broker permanently wired to the fake producer.
        config.broker_config.producer = old_producer
        config.broker_config.js_producer = old_js_producer
class TestNatsBroker(TestBroker[NatsBroker]):
    """In-memory test harness for ``NatsBroker``.

    Publishing is routed through a ``FakeProducer`` and the connection is
    replaced with ``AsyncMock`` objects, so no NATS server is required.
    """

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: NatsBroker,
        publisher: "LogicPublisher",
    ) -> tuple["LogicSubscriber[Any]", bool]:
        """Find or create a subscriber that listens to ``publisher``'s subject.

        Returns:
            A ``(subscriber, is_real)`` pair. ``is_real`` is ``True`` when an
            already-registered subscriber matched, and ``False`` when a new
            non-persistent subscriber had to be registered on the broker.
        """
        stream_name = publisher.stream.name if publisher.stream else None

        existing = next(
            (
                candidate
                for candidate in cast("list[LogicSubscriber[Any]]", broker.subscribers)
                if _is_handler_matches(candidate, publisher.subject, stream_name)
            ),
            None,
        )
        if existing is not None:
            return existing, True

        temporary = broker.subscriber(
            publisher.subject,
            persistent=False,
            stream=stream_name,
        )
        return temporary, False

    @contextmanager
    def _patch_producer(self, broker: NatsBroker) -> Iterator[None]:
        """Swap ``broker``'s producers for a ``FakeProducer`` while active."""
        replacement = FakeProducer(broker)

        with ExitStack() as stack:
            stack.enter_context(change_producer(broker.config, replacement))
            yield

    async def _fake_connect(
        self,
        broker: NatsBroker,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Attach mock connections unless ``broker`` already has one."""
        if broker.config.connection_state:
            return
        broker.config.connection_state.connect(AsyncMock(), AsyncMock())

    def _fake_start(self, broker: NatsBroker, *args: Any, **kwargs: Any) -> None:
        """Ensure a (mock) connection exists, then defer to the base start."""
        already_connected = bool(broker.config.connection_state)
        if not already_connected:
            broker.config.connection_state.connect(AsyncMock(), AsyncMock())
        return super()._fake_start(broker, *args, **kwargs)
class FakeProducer(NatsFastProducer):
    """In-memory producer that hands published messages straight to subscribers.

    Instead of talking to a NATS server, ``publish`` and ``request`` build a
    ``PatchedMessage`` and run it through the matching subscribers registered
    on the wrapped broker.
    """

    def __init__(self, broker: NatsBroker) -> None:
        # NOTE(review): ``NatsFastProducer.__init__`` is deliberately not
        # called; only the attributes used by this class are set up. Confirm
        # the parent holds no other state that publish/request rely on.
        self.broker = broker

        default = NatsParser(pattern="", is_ack_disabled=True)
        # Combine the broker's custom parser/decoder with the NATS defaults.
        self._parser = ParserComposition(broker._parser, default.parse_message)
        self._decoder = ParserComposition(broker._decoder, default.decode_message)

    @override
    async def publish(self, cmd: "NatsPublishCommand") -> None:
        """Deliver ``cmd`` to every matching subscriber.

        Among subscribers that share a queue group, only the first one found
        receives the message (see ``_find_handler``).
        """
        incoming = build_message(
            message=cmd.body,
            subject=cmd.destination,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            reply_to=cmd.reply_to,
            serializer=self.broker.config.fd_config._serializer,
        )

        for handler in _find_handler(
            cast("list[LogicSubscriber[Any]]", self.broker.subscribers),
            cmd.destination,
            cmd.stream,
        ):
            await self._execute_handler(
                self._as_handler_input(handler, incoming),
                cmd.destination,
                handler,
            )

    @override
    async def request(self, cmd: "NatsPublishCommand") -> "PatchedMessage":
        """Deliver ``cmd`` to the first matching subscriber and return its reply.

        Raises:
            SubscriberNotFound: If no subscriber matches the destination/stream.
            TimeoutError: If the handler does not finish within ``cmd.timeout``
                (raised by ``anyio.fail_after``).
        """
        incoming = build_message(
            message=cmd.body,
            subject=cmd.destination,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )

        for handler in _find_handler(
            cast("list[LogicSubscriber[Any]]", self.broker.subscribers),
            cmd.destination,
            cmd.stream,
        ):
            msg = self._as_handler_input(handler, incoming)
            # Only the first matching subscriber answers the request.
            with anyio.fail_after(cmd.timeout):
                return await self._execute_handler(msg, cmd.destination, handler)

        raise SubscriberNotFound

    @staticmethod
    def _as_handler_input(
        handler: "LogicSubscriber[Any]",
        incoming: "PatchedMessage",
    ) -> "list[PatchedMessage] | PatchedMessage":
        """Wrap ``incoming`` in a list for batch pull subscribers, else pass it through."""
        if (pull := getattr(handler, "pull_sub", None)) and pull.batch:
            return [incoming]
        return incoming

    async def _execute_handler(
        self,
        msg: Any,
        subject: str,
        handler: "LogicSubscriber[Any]",
    ) -> "PatchedMessage":
        """Run ``handler`` on ``msg`` and wrap its result as a message on ``subject``.

        The body, headers and correlation id are taken from the object
        returned by ``handler.process_message``.
        """
        result = await handler.process_message(msg)

        return build_message(
            subject=subject,
            message=result.body,
            headers=result.headers,
            correlation_id=result.correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )
def _find_handler(
    subscribers: Iterable["LogicSubscriber[Any]"],
    subject: str,
    stream: str | None = None,
) -> Generator["LogicSubscriber[Any]", None, None]:
    """Yield the subscribers that should receive a message on ``subject``.

    A subscriber qualifies when ``_is_handler_matches`` accepts it. Among
    qualifying subscribers that share a truthy ``queue`` value, only the first
    one encountered is yielded. Subscribers without a queue are always yielded.
    """
    seen_queues = set()

    for candidate in subscribers:
        if not _is_handler_matches(candidate, subject, stream):
            continue

        queue = getattr(candidate, "queue", None)
        if queue:
            if queue in seen_queues:
                # This queue group already has a recipient.
                continue
            seen_queues.add(queue)

        yield candidate
def _is_handler_matches(
    handler: "LogicSubscriber[Any]",
    subject: str,
    stream: str | None = None,
) -> bool:
    """Tell whether ``handler`` should receive a message published to ``subject``.

    When ``stream`` is given, the handler must have a ``stream`` attribute
    whose ``name`` equals it. The subject must then match the handler's
    ``clear_subject`` or one of its ``filter_subjects``, as judged by
    ``is_subject_match_wildcard``.
    """
    if stream:
        handler_stream = getattr(handler, "stream", None)
        if not handler_stream or stream != handler_stream.name:
            return False

    if is_subject_match_wildcard(subject, handler.clear_subject):
        return True

    return any(
        is_subject_match_wildcard(subject, pattern)
        for pattern in handler.filter_subjects or ()
    )
def build_message(
    message: "SendableMessage",
    subject: str,
    *,
    reply_to: str = "",
    correlation_id: str | None = None,
    headers: dict[str, str] | None = None,
    serializer: Optional["SerializerProto"] = None,
) -> "PatchedMessage":
    """Encode ``message`` into a ``PatchedMessage`` for in-memory delivery.

    The ``content-type`` and ``correlation_id`` headers are filled in first.
    Entries in ``headers`` are applied afterwards and therefore override them.
    A new id from ``gen_cor_id()`` is used when ``correlation_id`` is falsy.
    """
    payload, content_type = encode_message(message, serializer=serializer)

    merged_headers = {
        "content-type": content_type or "",
        "correlation_id": correlation_id or gen_cor_id(),
    }
    merged_headers.update(headers or {})

    return PatchedMessage(
        # No real client is attached; ack-style methods are no-ops on this type.
        _client=None,  # type: ignore[arg-type]
        subject=subject,
        reply=reply_to,
        data=payload,
        headers=merged_headers,
    )
class PatchedMessage(Msg):
    """NATS message whose acknowledgement methods do nothing.

    Messages used by the in-memory test broker are built without a client,
    so there is no server to acknowledge, reject or extend them against.
    """

    async def ack(self) -> None:
        """Accept the message; intentionally a no-op."""

    async def ack_sync(
        self,
        timeout: float = 1,
    ) -> "PatchedMessage":  # pragma: no cover
        """Return this message without contacting a server."""
        return self

    async def nak(self, delay: float | None = None) -> None:
        """Reject the message; intentionally a no-op."""

    async def term(self) -> None:
        """Terminate redelivery; intentionally a no-op."""

    async def in_progress(self) -> None:
        """Signal work in progress; intentionally a no-op."""