Coverage for faststream / redis / testing.py: 98%

135 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import re 

2from collections.abc import Iterator, Sequence 

3from contextlib import ExitStack, contextmanager 

4from typing import ( 

5 TYPE_CHECKING, 

6 Any, 

7 Optional, 

8 Protocol, 

9 Union, 

10 cast, 

11) 

12from unittest.mock import AsyncMock, MagicMock 

13 

14import anyio 

15from typing_extensions import TypedDict, override 

16 

17from faststream._internal.endpoint.utils import ParserComposition 

18from faststream._internal.testing.broker import TestBroker, change_producer 

19from faststream.exceptions import SetupError, SubscriberNotFound 

20from faststream.message import gen_cor_id 

21from faststream.redis.broker.broker import RedisBroker 

22from faststream.redis.message import ( 

23 BatchListMessage, 

24 BatchStreamMessage, 

25 DefaultListMessage, 

26 DefaultStreamMessage, 

27 PubSubMessage, 

28 bDATA_KEY, 

29) 

30from faststream.redis.parser import MessageFormat, ParserConfig, RedisPubSubParser 

31from faststream.redis.publisher.producer import RedisFastProducer 

32from faststream.redis.response import DestinationType, RedisPublishCommand 

33from faststream.redis.schemas import INCORRECT_SETUP_MSG 

34from faststream.redis.subscriber.usecases.channel_subscriber import ChannelSubscriber 

35from faststream.redis.subscriber.usecases.list_subscriber import _ListHandlerMixin 

36from faststream.redis.subscriber.usecases.stream_subscriber import _StreamHandlerMixin 

37 

38if TYPE_CHECKING: 

39 from fast_depends.library.serializer import SerializerProto 

40 

41 from faststream._internal.basic_types import SendableMessage 

42 from faststream.redis.publisher.usecase import LogicPublisher 

43 from faststream.redis.subscriber.usecases.basic import LogicSubscriber 

44 

# Public API of this module: only the Redis test broker is exported.
__all__ = ("TestRedisBroker",)

46 

47 

class TestRedisBroker(TestBroker[RedisBroker]):
    """A class to test Redis brokers."""

    @contextmanager
    def _patch_producer(self, broker: RedisBroker) -> Iterator[None]:
        """Swap every producer reachable from *broker* for a ``FakeProducer``.

        Both the broker-level producer and each publisher's own producer are
        replaced via ``change_producer``; the ``ExitStack`` unwinds all of
        those contexts when this context manager exits.
        """
        with ExitStack() as es:
            es.enter_context(
                change_producer(
                    broker.config.broker_config, FakeProducer(broker, broker.config)
                ),
            )

            for publisher in cast("list[LogicPublisher]", broker.publishers):
                es.enter_context(
                    change_producer(publisher, FakeProducer(broker, publisher.config)),
                )

            yield

    @staticmethod
    def create_publisher_fake_subscriber(
        broker: RedisBroker,
        publisher: "LogicPublisher",
    ) -> tuple["LogicSubscriber", bool]:
        """Find the subscriber that would receive *publisher*'s messages.

        The first already-registered subscriber whose channel/list/stream
        matches the publisher's destination is returned. If none matches, a
        non-persistent subscriber is created on *broker* from the publisher's
        subscriber options.

        Returns:
            ``(subscriber, is_real)``. ``is_real`` is ``False`` when the
            subscriber was created here rather than found.
        """
        named_property = publisher.subscriber_property(name_only=True)
        visitors = (ChannelVisitor(), ListVisitor(), StreamVisitor())

        # Stop at the first matching subscriber. The nested loops previously
        # kept scanning and ended up returning the last match.
        sub: LogicSubscriber | None = next(
            (
                handler
                for handler in cast("list[LogicSubscriber]", broker.subscribers)
                if any(
                    visitor.visit(**named_property, sub=handler)
                    for visitor in visitors
                )
            ),
            None,
        )

        if sub is not None:
            return sub, True

        sub_options = publisher.subscriber_property(name_only=False)
        return broker.subscriber(**sub_options, persistent=False), False

    @staticmethod
    async def _fake_connect(  # type: ignore[override]
        broker: RedisBroker,
        *args: Any,
        **kwargs: Any,
    ) -> MagicMock:
        """Install a mocked Redis client on *broker* instead of connecting.

        ``pubsub()`` always returns the same ``AsyncMock``. Its
        ``get_message`` only sleeps for the requested timeout and returns
        ``None``. ``aclose``, ``xack`` and ``xdel`` are ``AsyncMock``s. The mock
        is stored as the broker connection's ``_client`` and also returned.
        """
        connection = MagicMock()

        pub_sub = AsyncMock()

        async def get_msg(*args: Any, timeout: float, **kwargs: Any) -> None:
            # Behave like an idle subscription: wait out the poll timeout and
            # never produce a message.
            await anyio.sleep(timeout)

        pub_sub.get_message = get_msg

        connection.pubsub.side_effect = lambda: pub_sub
        connection.aclose = AsyncMock()

        connection.xack = AsyncMock()
        connection.xdel = AsyncMock()

        broker.config.broker_config.connection._client = connection
        return connection

116 

117 

class FakeProducer(RedisFastProducer):
    """Producer that delivers messages in-process instead of talking to Redis.

    Every publish/request is encoded with the command's message format and
    handed straight to the broker's subscribers whose channel, list or stream
    matches the command destination.
    """

    def __init__(self, broker: RedisBroker, config: ParserConfig) -> None:
        self.broker = broker

        default = RedisPubSubParser(config)

        # Compose the broker's own parser/decoder with the default pub/sub ones.
        self._parser = ParserComposition(
            broker._parser,
            default.parse_message,
        )
        self._decoder = ParserComposition(
            broker._decoder,
            default.decode_message,
        )

    @override
    async def publish(self, cmd: "RedisPublishCommand") -> int | bytes:
        """Deliver *cmd* to every matching subscriber; always returns ``0``.

        Raises:
            SetupError: *cmd* does not resolve to exactly one destination kind.
        """
        body = build_message(
            message=cmd.body,
            reply_to=cmd.reply_to,
            correlation_id=cmd.correlation_id or gen_cor_id(),
            headers=cmd.headers,
            message_format=cmd.message_format,
            serializer=self.broker.config.fd_config._serializer,
        )

        for msg, handler in self._iter_matching(cmd, body):
            await self._execute_handler(msg, handler)

        return 0

    @override
    async def request(self, cmd: "RedisPublishCommand") -> "PubSubMessage":
        """Deliver *cmd* to the first matching subscriber and return its reply.

        Raises:
            SubscriberNotFound: no subscriber matches the destination.
            TimeoutError: the handler did not finish within ``cmd.timeout``
                (raised by ``anyio.fail_after`` when a timeout is set).
        """
        body = build_message(
            message=cmd.body,
            correlation_id=cmd.correlation_id or gen_cor_id(),
            headers=cmd.headers,
            message_format=cmd.message_format,
            serializer=self.broker.config.fd_config._serializer,
        )

        for msg, handler in self._iter_matching(cmd, body):
            with anyio.fail_after(cmd.timeout):
                return await self._execute_handler(msg, handler)

        raise SubscriberNotFound

    @override
    async def publish_batch(self, cmd: "RedisPublishCommand") -> int:
        """Deliver a batch published to a list; always returns ``0``.

        A batch list subscriber receives all items as one batch message. A
        regular list subscriber receives one message per item. The per-item
        delivery assumes the real producer pushes each item as a separate list
        element (not visible here), so a non-batch consumer would take them one
        at a time.
        """
        data_to_send = [
            build_message(
                m,
                correlation_id=cmd.correlation_id or gen_cor_id(),
                headers=cmd.headers,
                message_format=cmd.message_format,
                serializer=self.broker.config.fd_config._serializer,
            )
            for m in cmd.batch_bodies
        ]

        visitor = ListVisitor()
        for handler in cast("list[LogicSubscriber]", self.broker.subscribers):  # pragma: no branch
            if not visitor.visit(list=cmd.destination, sub=handler):
                continue

            list_handler = cast("_ListHandlerMixin", handler)
            # Batch subscribers get the whole batch at once. Regular
            # subscribers get each item separately instead of being skipped.
            bodies: list[Any] = (
                [data_to_send] if list_handler.list_sub.batch else data_to_send
            )

            for item in bodies:
                msg = visitor.get_message(
                    channel=cmd.destination,
                    body=item,
                    sub=list_handler,
                )
                await self._execute_handler(msg, handler)

        return 0

    def _iter_matching(
        self,
        cmd: "RedisPublishCommand",
        body: bytes,
    ) -> Iterator[tuple[Any, "LogicSubscriber"]]:
        """Yield ``(raw_message, subscriber)`` for each subscriber matching *cmd*.

        Raises:
            SetupError: *cmd* does not resolve to exactly one destination kind.
        """
        destination = _make_destination_kwargs(cmd)
        visitors = (ChannelVisitor(), ListVisitor(), StreamVisitor())

        for handler in cast("list[LogicSubscriber]", self.broker.subscribers):  # pragma: no branch
            for visitor in visitors:
                if visited_ch := visitor.visit(**destination, sub=handler):
                    msg = visitor.get_message(
                        visited_ch,
                        body,
                        handler,  # type: ignore[arg-type]
                    )
                    yield msg, handler

    async def _execute_handler(
        self,
        msg: Any,
        handler: "LogicSubscriber",
    ) -> "PubSubMessage":
        """Run *handler* on *msg* and wrap its result as a pub/sub reply message.

        The reply body is encoded with the handler's configured message format.
        """
        result = await handler.process_message(msg)

        return PubSubMessage(
            type="message",
            data=build_message(
                message=result.body,
                headers=result.headers,
                correlation_id=result.correlation_id or "",
                message_format=handler.config.message_format,
                serializer=self.broker.config.fd_config._serializer,
            ),
            channel="",
            pattern=None,
        )

238 

239 

240def build_message( 

241 message: Union[Sequence["SendableMessage"], "SendableMessage"], 

242 *, 

243 correlation_id: str, 

244 message_format: type["MessageFormat"], 

245 reply_to: str = "", 

246 headers: dict[str, Any] | None = None, 

247 serializer: Optional["SerializerProto"] = None, 

248) -> bytes: 

249 return message_format.encode( 

250 message=message, 

251 reply_to=reply_to, 

252 headers=headers, 

253 correlation_id=correlation_id, 

254 serializer=serializer, 

255 ) 

256 

257 

258class Visitor(Protocol): 

259 def visit( 

260 self, 

261 *, 

262 channel: str | None, 

263 list: str | None, 

264 stream: str | None, 

265 sub: "LogicSubscriber", 

266 ) -> str | None: ... 

267 

268 def get_message(self, channel: str, body: Any, sub: "LogicSubscriber") -> Any: ... 

269 

270 

271class ChannelVisitor(Visitor): 

272 def visit( 

273 self, 

274 *, 

275 sub: "LogicSubscriber", 

276 channel: str | None = None, 

277 list: str | None = None, 

278 stream: str | None = None, 

279 ) -> str | None: 

280 if channel is None or not isinstance(sub, ChannelSubscriber): 

281 return None 

282 

283 sub_channel = sub.channel 

284 

285 if ( 

286 sub_channel.pattern 

287 and bool( 

288 re.match( 

289 sub_channel.name.replace(".", "\\.").replace("*", ".*"), 

290 channel or "", 

291 ), 

292 ) 

293 ) or channel == sub_channel.name: 

294 return channel 

295 

296 return None 

297 

298 def get_message( # type: ignore[override] 

299 self, 

300 channel: str, 

301 body: Any, 

302 sub: "ChannelSubscriber", 

303 ) -> Any: 

304 return PubSubMessage( 

305 type="message", 

306 data=body, 

307 channel=channel, 

308 pattern=sub.channel.pattern.encode() if sub.channel.pattern else None, 

309 ) 

310 

311 

class ListVisitor(Visitor):
    """Routes list destinations to list subscribers and builds list messages."""

    def visit(
        self,
        *,
        sub: "LogicSubscriber",
        channel: str | None = None,
        list: str | None = None,
        stream: str | None = None,
    ) -> str | None:
        """Return *list* when *sub* is a list subscriber for exactly that list name."""
        matched = (
            list is not None
            and isinstance(sub, _ListHandlerMixin)
            and list == sub.list_sub.name
        )
        return list if matched else None

    def get_message(  # type: ignore[override]
        self,
        channel: str,
        body: Any,
        sub: "_ListHandlerMixin",
    ) -> Any:
        """Wrap *body* as a single list message, or as a batch for batch subscribers.

        For batch subscribers a non-list *body* is wrapped into a one-item list.
        """
        if not sub.list_sub.batch:
            return DefaultListMessage(
                type="list",
                channel=channel,
                data=body,
            )

        batch_data = body if isinstance(body, list) else [body]
        return BatchListMessage(
            type="blist",
            channel=channel,
            data=batch_data,
        )

347 

348 

class StreamVisitor(Visitor):
    """Routes stream destinations to stream subscribers and builds stream messages."""

    def visit(
        self,
        *,
        sub: "LogicSubscriber",
        channel: str | None = None,
        list: str | None = None,
        stream: str | None = None,
    ) -> str | None:
        """Return *stream* when *sub* is a stream subscriber for exactly that stream."""
        is_match = (
            stream is not None
            and isinstance(sub, _StreamHandlerMixin)
            and stream == sub.stream_sub.name
        )
        return stream if is_match else None

    def get_message(  # type: ignore[override]
        self,
        channel: str,
        body: Any,
        sub: "_StreamHandlerMixin",
    ) -> Any:
        """Wrap *body* under ``bDATA_KEY`` as one stream entry (a batch of one if batching)."""
        entry = {bDATA_KEY: body}

        if not sub.stream_sub.batch:
            return DefaultStreamMessage(
                type="stream",
                channel=channel,
                data=entry,
                message_ids=[],
            )

        return BatchStreamMessage(
            type="bstream",
            channel=channel,
            data=[entry],
            message_ids=[],
        )

386 

387 

# Optional destination keys passed as keyword arguments to a visitor's ``visit``.
_DestinationKwargs = TypedDict(
    "_DestinationKwargs",
    {"channel": str, "list": str, "stream": str},
    total=False,
)

392 

393 

def _make_destination_kwargs(cmd: RedisPublishCommand) -> _DestinationKwargs:
    """Turn *cmd*'s destination into the single keyword the visitors expect.

    ``cmd.destination_type`` selects which of ``channel``, ``list`` or
    ``stream`` is set to ``cmd.destination``.

    Raises:
        SetupError: the destination type is none of the three known kinds.
    """
    kind = cmd.destination_type
    result: _DestinationKwargs = {}

    if kind is DestinationType.Channel:
        result["channel"] = cmd.destination
    elif kind is DestinationType.List:
        result["list"] = cmd.destination
    elif kind is DestinationType.Stream:
        result["stream"] = cmd.destination

    if not result:
        raise SetupError(INCORRECT_SETUP_MSG)

    return result