Coverage for faststream / redis / subscriber / usecases / channel_subscriber.py: 96%

75 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1from collections.abc import AsyncIterator 

2from typing import TYPE_CHECKING, Any, Optional, TypeAlias 

3 

4import anyio 

5from redis.asyncio.client import ( 

6 PubSub as RPubSub, 

7) 

8from typing_extensions import override 

9 

10from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin 

11from faststream._internal.endpoint.utils import process_msg 

12from faststream.redis.message import ( 

13 PubSubMessage, 

14 RedisChannelMessage, 

15) 

16from faststream.redis.parser import ( 

17 RedisPubSubParser, 

18) 

19 

20from .basic import LogicSubscriber 

21 

22if TYPE_CHECKING: 

23 from faststream._internal.endpoint.subscriber import SubscriberSpecification 

24 from faststream._internal.endpoint.subscriber.call_item import ( 

25 CallsCollection, 

26 ) 

27 from faststream.message import StreamMessage as BrokerStreamMessage 

28 from faststream.redis.schemas import PubSub 

29 from faststream.redis.subscriber.config import RedisSubscriberConfig 

30 

31 

32TopicName: TypeAlias = bytes 

33Offset: TypeAlias = bytes 

34 

35 

class ChannelSubscriber(LogicSubscriber):
    """Subscriber that consumes messages from a Redis Pub/Sub channel.

    The configured channel is subscribed with ``PSUBSCRIBE`` when it is a
    pattern and with ``SUBSCRIBE`` otherwise. The resulting redis-py ``PubSub``
    object is then polled for messages.
    """

    def __init__(
        self,
        config: "RedisSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Any]",
    ) -> None:
        """Create the subscriber and install the Pub/Sub parser into ``config``.

        Args:
            config: Subscriber configuration. ``config.channel_sub`` must be set.
                Its ``decoder`` and ``parser`` attributes are overwritten with
                the ``RedisPubSubParser`` callbacks.
            specification: Subscriber specification forwarded to the base class.
            calls: Registered handler calls forwarded to the base class.
        """
        assert config.channel_sub
        parser = RedisPubSubParser(config, pattern=config.channel_sub.path_regex)
        config.decoder = parser.decode_message
        config.parser = parser.parse_message
        super().__init__(config, specification, calls)

        self._channel = config.channel_sub
        # Active redis-py PubSub object; ``None`` whenever the subscriber is stopped.
        self.subscription: RPubSub | None = None

    @property
    def channel(self) -> "PubSub":
        """Channel schema with the outer configuration prefix applied."""
        return self._channel.add_prefix(self._outer_config.prefix)

    def get_log_context(
        self,
        message: Optional["BrokerStreamMessage[Any]"],
    ) -> dict[str, str]:
        """Build the logging context for ``message``, tagged with the channel name."""
        return self.build_log_context(
            message=message,
            channel=self.channel.name,
        )

    @override
    async def start(self) -> None:
        """Open a Pub/Sub connection, subscribe to the channel and start consuming.

        Calling ``start`` while a subscription is already active is a no-op.
        """
        if self.subscription:
            return

        self.subscription = psub = self._client.pubsub()

        channel = self.channel
        if channel.pattern:
            await psub.psubscribe(channel.name)
        else:
            await psub.subscribe(channel.name)

        await super().start(psub)

    @override
    async def stop(self) -> None:
        """Stop consuming and release the Pub/Sub connection.

        The subscription reference is cleared before unsubscribing. The
        connection is closed in a ``finally`` block, so a failing
        ``unsubscribe()`` neither leaks the connection nor prevents a later
        ``start()`` from subscribing again.
        """
        await super().stop()

        if self.subscription is None:
            return

        psub, self.subscription = self.subscription, None
        try:
            await psub.unsubscribe()
        finally:
            await psub.aclose()  # type: ignore[attr-defined]

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
    ) -> "RedisChannelMessage | None":
        """Wait up to ``timeout`` seconds for a single channel message.

        Args:
            timeout: Maximum number of seconds to wait for a message.

        Returns:
            The parsed message. When nothing arrives in time, ``None`` is passed
            to ``process_msg``, and the return annotation indicates ``None`` is
            returned in that case.

        Raises:
            AssertionError: If the subscriber is not started or has handlers.
        """
        assert self.subscription, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        raw_message = await self._poll_message(self.subscription, timeout)

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        msg: RedisChannelMessage | None = await process_msg(  # type: ignore[assignment]
            msg=raw_message,
            middlewares=(
                m(raw_message, context=context) for m in self._broker_middlewares
            ),
            parser=async_parser,
            decoder=async_decoder,
        )
        return msg

    @override
    async def __aiter__(self) -> AsyncIterator["RedisChannelMessage"]:  # type: ignore[override]
        """Yield parsed channel messages indefinitely.

        Polling happens in 5-second windows. An empty window is skipped rather
        than ending the iteration. Every window starts with no message, so a
        message is never yielded twice.

        Raises:
            AssertionError: If the subscriber is not started or has handlers.
        """
        assert self.subscription, "You should start subscriber at first."
        assert not self.calls, (
            "You can't use iterator if subscriber has registered handlers."
        )

        timeout = 5

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        while True:
            # ``_poll_message`` starts from ``None`` on every call, so a
            # cancelled window can never re-deliver the previous message.
            raw_message = await self._poll_message(self.subscription, timeout)

            if raw_message is None:
                continue

            msg: RedisChannelMessage = await process_msg(  # type: ignore[assignment]
                msg=raw_message,
                middlewares=(
                    m(raw_message, context=context) for m in self._broker_middlewares
                ),
                parser=async_parser,
                decoder=async_decoder,
            )
            yield msg

    async def _poll_message(
        self,
        psub: RPubSub,
        timeout: float,
    ) -> PubSubMessage | None:
        """Poll ``psub`` until a message arrives or ``timeout`` seconds elapse.

        Between empty polls it sleeps ``timeout / 10`` seconds. Returns ``None``
        if the timeout expires first, including when it expires during a pending
        ``get_message`` call.
        """
        sleep_interval = timeout / 10
        raw_message: PubSubMessage | None = None

        with anyio.move_on_after(timeout):
            while (raw_message := await self._get_message(psub)) is None:  # noqa: ASYNC110
                await anyio.sleep(sleep_interval)

        return raw_message

    async def _get_message(self, psub: RPubSub) -> PubSubMessage | None:
        """Fetch one data message from ``psub``, or ``None`` if none is ready.

        Subscribe/unsubscribe confirmations are ignored. redis-py waits at most
        ``channel.polling_interval`` seconds. The channel name is decoded from
        bytes to ``str``; this assumes the client does not already decode
        responses.
        """
        raw_msg = await psub.get_message(
            ignore_subscribe_messages=True,
            timeout=self.channel.polling_interval,
        )

        if raw_msg:
            return PubSubMessage(
                type=raw_msg["type"],
                data=raw_msg["data"],
                channel=raw_msg["channel"].decode(),
                pattern=raw_msg["pattern"],
            )

        return None

    async def _get_msgs(self, psub: RPubSub) -> None:
        """Poll once and hand any received message to ``consume_one``.

        Presumably invoked repeatedly by the base class consume loop.
        """
        if msg := await self._get_message(psub):
            await self.consume_one(msg)

173 

174 

class ChannelConcurrentSubscriber(
    ConcurrentMixin["BrokerStreamMessage[Any]"],
    ChannelSubscriber,
):
    """Channel subscriber that hands received messages to concurrent workers.

    ``consume_one`` does not process a message inline. It forwards the message
    to ``ConcurrentMixin._put_msg``, presumably a queue drained by the task(s)
    launched through ``start_consume_task``.
    """

    async def start(self) -> None:
        """Start the channel subscription, then launch the consumer task(s)."""
        await super().start()
        # Consumers are launched only after the subscription has been set up.
        self.start_consume_task()

    async def consume_one(self, msg: "BrokerStreamMessage[Any]") -> None:
        """Enqueue ``msg`` for the concurrent workers instead of handling it here."""
        await self._put_msg(msg)