Coverage for faststream / nats / testing.py: 98%

92 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1from collections.abc import Generator, Iterable, Iterator 

2from contextlib import ExitStack, contextmanager 

3from typing import TYPE_CHECKING, Any, Optional, cast 

4from unittest.mock import AsyncMock 

5 

6import anyio 

7from nats.aio.msg import Msg 

8from typing_extensions import override 

9 

10from faststream._internal.endpoint.utils import ParserComposition 

11from faststream._internal.testing.broker import TestBroker 

12from faststream.exceptions import SubscriberNotFound 

13from faststream.message import encode_message, gen_cor_id 

14from faststream.nats.broker import NatsBroker 

15from faststream.nats.parser import NatsParser 

16from faststream.nats.publisher.producer import NatsFastProducer 

17from faststream.nats.schemas.js_stream import is_subject_match_wildcard 

18 

19if TYPE_CHECKING: 

20 from fast_depends.library.serializer import SerializerProto 

21 

22 from faststream._internal.basic_types import SendableMessage 

23 from faststream._internal.configs.broker import ConfigComposition 

24 from faststream.nats.configs import NatsBrokerConfig 

25 from faststream.nats.publisher.usecase import LogicPublisher 

26 from faststream.nats.response import NatsPublishCommand 

27 from faststream.nats.subscriber.usecases.basic import LogicSubscriber 

28 

# Names exported by ``from faststream.nats.testing import *``.
__all__ = ("TestNatsBroker",)

30 

31 

@contextmanager
def change_producer(
    config: "ConfigComposition[NatsBrokerConfig]",
    producer: "NatsFastProducer",
) -> Generator[None, None, None]:
    """Temporarily route core and JetStream publishing through ``producer``.

    While the ``with`` block is active, both ``config.broker_config.producer``
    and ``config.broker_config.js_producer`` are set to ``producer``. The
    previous values are restored on exit, including when the block raises.

    Args:
        config: Broker configuration whose ``broker_config`` holds the
            producers to swap.
        producer: Replacement producer installed in both slots.
    """
    old_producer = config.broker_config.producer
    config.broker_config.producer = producer

    old_js_producer = config.broker_config.js_producer
    config.broker_config.js_producer = producer

    try:
        yield
    finally:
        # Restore unconditionally. Without ``finally`` an exception raised
        # inside the ``with`` body would leave the fake producer installed.
        config.broker_config.producer = old_producer
        config.broker_config.js_producer = old_js_producer

48 

49 

class TestNatsBroker(TestBroker[NatsBroker]):
    """Test wrapper for ``NatsBroker`` using a ``FakeProducer`` and mock connections."""

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: NatsBroker,
        publisher: "LogicPublisher",
    ) -> tuple["LogicSubscriber[Any]", bool]:
        """Find or create a subscriber that receives ``publisher``'s messages.

        The first registered subscriber matching the publisher's subject (and
        stream, if the publisher has one) is reused. If none matches, a new
        non-persistent subscriber is registered on that subject and stream.

        Returns:
            ``(subscriber, is_real)``. ``is_real`` is ``True`` for an existing
            subscriber and ``False`` for a newly created one.
        """
        stream_name = publisher.stream.name if publisher.stream else None

        for candidate in broker.subscribers:
            subscriber = cast("LogicSubscriber[Any]", candidate)
            if _is_handler_matches(subscriber, publisher.subject, stream_name):
                return subscriber, True

        created = broker.subscriber(
            publisher.subject,
            persistent=False,
            stream=stream_name,
        )
        return created, False

    @contextmanager
    def _patch_producer(self, broker: NatsBroker) -> Iterator[None]:
        """Install a ``FakeProducer`` on ``broker`` for the duration of the block."""
        with ExitStack() as stack:
            stack.enter_context(
                change_producer(broker.config, FakeProducer(broker)),
            )
            yield

    @staticmethod
    def _connect_mock_state(broker: NatsBroker) -> None:
        # Mark the connection state as connected using mock clients, but only
        # if it is not already connected.
        if not broker.config.connection_state:
            broker.config.connection_state.connect(AsyncMock(), AsyncMock())

    async def _fake_connect(
        self,
        broker: NatsBroker,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Replace a real connection with mock clients."""
        self._connect_mock_state(broker)

    def _fake_start(self, broker: NatsBroker, *args: Any, **kwargs: Any) -> None:
        """Ensure a mock connection exists, then defer to the base start logic."""
        self._connect_mock_state(broker)
        return super()._fake_start(broker, *args, **kwargs)

98 

99 

class FakeProducer(NatsFastProducer):
    """Producer that hands messages directly to the broker's own subscribers.

    ``NatsFastProducer.__init__`` is not called. Only the broker reference and
    the parser/decoder compositions are set up here.
    """

    def __init__(self, broker: NatsBroker) -> None:
        self.broker = broker

        # Broker-level parser/decoder composed with a NatsParser fallback that
        # has acknowledgements disabled.
        fallback = NatsParser(pattern="", is_ack_disabled=True)
        self._parser = ParserComposition(broker._parser, fallback.parse_message)
        self._decoder = ParserComposition(broker._decoder, fallback.decode_message)

    @staticmethod
    def _wrap_for_handler(
        handler: "LogicSubscriber[Any]",
        incoming: "PatchedMessage",
    ) -> "list[PatchedMessage] | PatchedMessage":
        """Return ``incoming`` wrapped in a list when ``handler`` is a batch pull subscriber."""
        pull_sub = getattr(handler, "pull_sub", None)
        if pull_sub and pull_sub.batch:
            return [incoming]
        return incoming

    @override
    async def publish(self, cmd: "NatsPublishCommand") -> None:
        """Run every subscriber selected by ``_find_handler`` on the command's message."""
        incoming = build_message(
            message=cmd.body,
            subject=cmd.destination,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            reply_to=cmd.reply_to,
            serializer=self.broker.config.fd_config._serializer,
        )

        subscribers = cast("list[LogicSubscriber[Any]]", self.broker.subscribers)
        for handler in _find_handler(subscribers, cmd.destination, cmd.stream):
            await self._execute_handler(
                self._wrap_for_handler(handler, incoming),
                cmd.destination,
                handler,
            )

    @override
    async def request(self, cmd: "NatsPublishCommand") -> "PatchedMessage":
        """Deliver the command to the first selected subscriber and return its reply.

        The handler call is bounded by ``anyio.fail_after(cmd.timeout)``.

        Raises:
            SubscriberNotFound: if no subscriber matches the destination.
        """
        incoming = build_message(
            message=cmd.body,
            subject=cmd.destination,
            headers=cmd.headers,
            correlation_id=cmd.correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )

        subscribers = cast("list[LogicSubscriber[Any]]", self.broker.subscribers)
        for handler in _find_handler(subscribers, cmd.destination, cmd.stream):
            payload = self._wrap_for_handler(handler, incoming)
            with anyio.fail_after(cmd.timeout):
                return await self._execute_handler(payload, cmd.destination, handler)

        raise SubscriberNotFound

    async def _execute_handler(
        self,
        msg: Any,
        subject: str,
        handler: "LogicSubscriber[Any]",
    ) -> "PatchedMessage":
        """Process ``msg`` with ``handler`` and package its result as a message on ``subject``."""
        response = await handler.process_message(msg)
        return build_message(
            subject=subject,
            message=response.body,
            headers=response.headers,
            correlation_id=response.correlation_id,
            serializer=self.broker.config.fd_config._serializer,
        )

175 

176 

def _find_handler(
    subscribers: Iterable["LogicSubscriber[Any]"],
    subject: str,
    stream: str | None = None,
) -> Generator["LogicSubscriber[Any]", None, None]:
    """Yield the subscribers that should receive a message on ``subject``.

    A subscriber is yielded when ``_is_handler_matches`` accepts it. Among
    matching subscribers that share the same truthy ``queue`` value, only
    the first one is yielded; subscribers without a queue are all yielded.
    This presumably mirrors NATS queue-group delivery (confirm).
    """
    seen_queues: set[Any] = set()
    for handler in subscribers:
        if not _is_handler_matches(handler, subject, stream):
            continue

        queue = getattr(handler, "queue", None)
        if queue:
            if queue in seen_queues:
                continue
            seen_queues.add(queue)

        yield handler

191 

192 

def _is_handler_matches(
    handler: "LogicSubscriber[Any]",
    subject: str,
    stream: str | None = None,
) -> bool:
    """Return whether ``handler`` subscribes to ``subject`` (and ``stream``).

    When ``stream`` is given, the handler must expose a ``stream`` whose
    ``name`` equals it. The subject then has to match either the handler's
    ``clear_subject`` or one of its ``filter_subjects``, as decided by
    ``is_subject_match_wildcard``.
    """
    if stream:
        handler_stream = getattr(handler, "stream", None)
        if not handler_stream or stream != handler_stream.name:
            return False

    if is_subject_match_wildcard(subject, handler.clear_subject):
        return True

    # Generator keeps filter_subjects evaluation lazy and short-circuiting.
    return any(
        is_subject_match_wildcard(subject, filter_subject)
        for filter_subject in handler.filter_subjects or ()
    )

213 

214 

def build_message(
    message: "SendableMessage",
    subject: str,
    *,
    reply_to: str = "",
    correlation_id: str | None = None,
    headers: dict[str, str] | None = None,
    serializer: Optional["SerializerProto"] = None,
) -> "PatchedMessage":
    """Encode ``message`` and wrap it in a ``PatchedMessage`` with no client.

    The generated headers contain ``content-type`` (empty string if the
    encoder returned none) and ``correlation_id`` (a fresh id from
    ``gen_cor_id`` when none is supplied). Entries in ``headers`` are merged
    last and therefore override those generated values.
    """
    payload, content_type = encode_message(message, serializer=serializer)

    merged_headers = {
        "content-type": content_type or "",
        "correlation_id": correlation_id or gen_cor_id(),
    }
    merged_headers.update(headers or {})

    return PatchedMessage(
        _client=None,  # type: ignore[arg-type]
        subject=subject,
        reply=reply_to,
        data=payload,
        headers=merged_headers,
    )

236 

237 

class PatchedMessage(Msg):
    """``nats.aio.msg.Msg`` whose acknowledgement methods are overridden as no-ops."""

    async def ack(self) -> None:
        """Do nothing."""

    async def ack_sync(
        self,
        timeout: float = 1,
    ) -> "PatchedMessage":  # pragma: no cover
        """Return this message unchanged; ``timeout`` is ignored."""
        return self

    async def nak(self, delay: float | None = None) -> None:
        """Do nothing; ``delay`` is ignored."""

    async def term(self) -> None:
        """Do nothing."""

    async def in_progress(self) -> None:
        """Do nothing."""