Coverage for faststream / redis / subscriber / usecases / stream_subscriber.py: 86%
128 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 01:48 +0000
1import asyncio
2import logging
3import math
4from collections.abc import AsyncIterator, Awaitable, Callable
5from typing import TYPE_CHECKING, Any, Optional, TypeAlias
7from redis.exceptions import ResponseError
8from typing_extensions import override
10from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin
11from faststream._internal.endpoint.utils import process_msg
12from faststream.redis.exceptions import StreamGroupNotFoundError
13from faststream.redis.message import (
14 BatchStreamMessage,
15 DefaultStreamMessage,
16 RedisStreamMessage,
17)
18from faststream.redis.parser import (
19 RedisBatchStreamParser,
20 RedisStreamParser,
21)
23from .basic import LogicSubscriber
25if TYPE_CHECKING:
26 from anyio import Event
28 from faststream._internal.endpoint.subscriber import SubscriberSpecification
29 from faststream._internal.endpoint.subscriber.call_item import (
30 CallsCollection,
31 )
32 from faststream.message import StreamMessage as BrokerStreamMessage
33 from faststream.redis.schemas import StreamSub
34 from faststream.redis.subscriber.config import RedisSubscriberConfig
37TopicName: TypeAlias = bytes
38Offset: TypeAlias = bytes
40ReadResponse = tuple[
41 tuple[
42 TopicName,
43 tuple[
44 tuple[
45 Offset,
46 dict[bytes, bytes],
47 ],
48 ...,
49 ],
50 ],
51 ...,
52]
53ReadCallable = Callable[[str], Awaitable[ReadResponse]]
class _StreamHandlerMixin(LogicSubscriber):
    """Shared behaviour of the Redis Stream subscribers.

    Creates the consumer group (when one is configured), runs the background
    fetch loop and implements the pull-style ``get_one`` / ``async for`` APIs.
    Subclasses supply ``_get_msgs``, which receives the ``read`` callable built
    in :meth:`start`.
    """

    def __init__(
        self,
        config: "RedisSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Any]",
    ) -> None:
        super().__init__(config, specification, calls)

        assert config.stream_sub
        self._stream_sub = config.stream_sub
        # Read cursor; advanced by `_get_msgs`, `get_one` and `__aiter__`.
        self.last_id = config.stream_sub.last_id
        self.min_idle_time = config.stream_sub.min_idle_time
        # `start_id` passed to XAUTOCLAIM; replaced by each reply's next id.
        self.autoclaim_start_id = b"0-0"

    @property
    def stream_sub(self) -> "StreamSub":
        """The configured stream, with the outer prefix applied."""
        return self._stream_sub.add_prefix(self._outer_config.prefix)

    def get_log_context(
        self,
        message: Optional["BrokerStreamMessage[Any]"],
    ) -> dict[str, str]:
        """Build the log context for ``message``, tagged with the stream name."""
        return self.build_log_context(
            message=message,
            channel=self.stream_sub.name,
        )

    @override
    async def _consume(self, *args: Any, start_signal: "Event") -> None:
        """Background loop: call ``_get_msgs`` while the subscriber is running.

        Raises:
            StreamGroupNotFoundError: the consumer group vanished (NOGROUP).
        """
        if await self._client.ping():
            start_signal.set()

        while self.running:
            try:
                await self._get_msgs(*args)

            except ResponseError as e:  # noqa: PERF203
                if "NOGROUP" in str(e):
                    msg = (
                        f"Consumer group `{self.stream_sub.group}` for stream "
                        f"`{self.stream_sub.name}` no longer exists. "
                        "The stream was likely deleted or flushed. "
                        "Stopping subscriber — restart the application to recreate the group."
                    )
                    raise StreamGroupNotFoundError(msg) from e
                # NOTE(review): other ResponseErrors propagate and end this loop;
                # they do not reach the logging handler below. Confirm intended.
                raise

            except Exception as e:
                self._log(
                    log_level=logging.ERROR,
                    message="Message fetch error",
                    exc_info=e,
                )

            finally:
                # Never leave `start` waiting, even if the first fetch failed.
                if not start_signal.is_set():
                    start_signal.set()

    @override
    async def start(self) -> None:
        """Ensure the consumer group exists, then start with a matching ``read``.

        Three ``read`` variants are built:
          * group + consumer, no ``min_idle_time``: XREADGROUP;
          * group + consumer with ``min_idle_time``: XAUTOCLAIM, one entry per call;
          * no group: plain XREAD from the cursor passed in.
        """
        client = self._client

        self.extra_watcher_options.update(
            redis=client,
            group=self.stream_sub.group,
        )

        stream = self.stream_sub

        read: ReadCallable

        if stream.group and stream.consumer:
            # A ">" read cursor is mapped to "$" for group creation.
            group_create_id = "$" if self.last_id == ">" else self.last_id
            try:
                await client.xgroup_create(
                    name=stream.name,
                    id=group_create_id,
                    groupname=stream.group,
                    mkstream=True,
                )
            except ResponseError as e:
                # Restarting against an existing group is fine.
                if "already exists" not in str(e):
                    raise

            if stream.min_idle_time is None:

                def read(
                    _: str,
                ) -> Awaitable[ReadResponse]:
                    # Always reads with the configured id, not the passed cursor.
                    return client.xreadgroup(
                        groupname=stream.group,
                        consumername=stream.consumer,
                        streams={stream.name: stream.last_id},
                        count=stream.max_records,
                        block=stream.polling_interval,
                        noack=stream.no_ack,
                    )

            else:

                async def read(_: str) -> ReadResponse:
                    stream_message = await client.xautoclaim(
                        name=self.stream_sub.name,
                        groupname=self.stream_sub.group,
                        consumername=self.stream_sub.consumer,
                        min_idle_time=self.min_idle_time,
                        start_id=self.autoclaim_start_id,
                        count=1,
                    )
                    stream_name = self.stream_sub.name.encode()
                    (next_id, messages, *_) = stream_message

                    # Update start_id for next call
                    self.autoclaim_start_id = next_id

                    if next_id == b"0-0" and not messages:
                        await asyncio.sleep(stream.polling_interval / 1000)  # ms to s
                        return ()

                    return ((stream_name, messages),)

        else:

            def read(
                last_id: str,
            ) -> Awaitable[ReadResponse]:
                return client.xread(
                    {stream.name: last_id},
                    block=stream.polling_interval,
                    count=stream.max_records,
                )

        await super().start(read)

    async def _fetch_one(
        self,
        timeout: float,
        *,
        idle_sleep: bool = False,
    ) -> tuple[bytes, bytes, dict[bytes, bytes]] | None:
        """Read at most one stream entry for the pull APIs.

        Args:
            timeout: Seconds XREAD / XREADGROUP may block. The XAUTOCLAIM path
                ignores it and returns immediately.
            idle_sleep: When a read comes back empty without having blocked,
                pause for the stream's polling interval. ``__aiter__`` uses this
                so it does not issue requests to Redis in a tight loop.

        Returns:
            ``(stream_name, message_id, fields)``, or ``None`` if nothing was read.
        """
        client = self._client
        stream = self.stream_sub

        if stream.group and stream.consumer and self.min_idle_time is not None:
            claimed = await client.xautoclaim(
                name=stream.name,
                groupname=stream.group,
                consumername=stream.consumer,
                min_idle_time=self.min_idle_time,
                start_id=self.autoclaim_start_id,
                count=1,
            )
            (next_id, messages, *_) = claimed
            # Resume the next claim where this one stopped.
            self.autoclaim_start_id = next_id
            if not messages:
                # Same condition on which the background XAUTOCLAIM `read` sleeps.
                if idle_sleep and next_id == b"0-0":
                    await asyncio.sleep((stream.polling_interval or 0) / 1000)
                return None
            ((message_id, raw_message),) = messages
            stream_name = stream.name.encode()
        else:
            if stream.group and stream.consumer:
                response = await client.xreadgroup(
                    groupname=stream.group,
                    consumername=stream.consumer,
                    streams={stream.name: self.last_id},
                    block=math.ceil(timeout * 1000),
                    count=1,
                )
            else:
                response = await client.xread(
                    {stream.name: self.last_id},
                    block=math.ceil(timeout * 1000),
                    count=1,
                )
            if not response:
                return None
            ((stream_name, entries),) = response
            if not entries:
                # A stream with no entries: presumably an exhausted read of this
                # consumer's pending history (explicit id rather than ">"), which
                # per the Redis docs does not block. Unpacking it would fail.
                if idle_sleep:
                    await asyncio.sleep((stream.polling_interval or 0) / 1000)
                return None
            ((message_id, raw_message),) = entries

        # ">" asks the group for never-delivered entries. Overwriting it with a
        # concrete id would turn every later group read into a history read.
        if self.last_id != ">":
            self.last_id = message_id.decode()

        return stream_name, message_id, raw_message

    async def _to_stream_message(
        self,
        stream_name: bytes,
        message_id: bytes,
        raw_message: dict[bytes, bytes],
        *,
        context: Any,
        parser: Any,
        decoder: Any,
    ) -> "RedisStreamMessage":
        """Wrap a raw entry and run it through broker middlewares, parser and decoder."""
        redis_incoming_msg = DefaultStreamMessage(
            type="stream",
            channel=stream_name.decode(),
            message_ids=[message_id],
            data=raw_message,
        )

        msg: RedisStreamMessage = await process_msg(  # type: ignore[assignment]
            msg=redis_incoming_msg,
            middlewares=(
                m(redis_incoming_msg, context=context) for m in self._broker_middlewares
            ),
            parser=parser,
            decoder=decoder,
        )
        return msg

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
    ) -> "RedisStreamMessage | None":
        """Fetch and parse a single message, or return ``None`` if none arrived."""
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        fetched = await self._fetch_one(timeout)
        if fetched is None:
            return None

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        return await self._to_stream_message(
            *fetched,
            context=context,
            parser=async_parser,
            decoder=async_decoder,
        )

    @override
    async def __aiter__(self) -> AsyncIterator["RedisStreamMessage"]:  # type: ignore[override]
        """Yield parsed messages one at a time, forever."""
        assert not self.calls, (
            "You can't use iterator if subscriber has registered handlers."
        )

        timeout = 5

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        while True:
            fetched = await self._fetch_one(timeout, idle_sleep=True)
            if fetched is None:
                continue

            yield await self._to_stream_message(
                *fetched,
                context=context,
                parser=async_parser,
                decoder=async_decoder,
            )
class StreamSubscriber(_StreamHandlerMixin):
    """Redis Stream subscriber that dispatches every entry as its own message."""

    def __init__(
        self,
        config: "RedisSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Any]",
    ) -> None:
        stream_parser = RedisStreamParser(config)
        config.decoder = stream_parser.decode_message
        config.parser = stream_parser.parse_message
        super().__init__(config, specification, calls)

    async def _get_msgs(self, read: ReadCallable) -> None:
        """Run one ``read`` and consume each returned entry individually."""
        response = await read(self.last_id)
        for stream_name, entries in response:
            if not entries:
                continue

            # Advance the cursor past this batch before dispatching it.
            self.last_id = entries[-1][0].decode()

            channel = stream_name.decode()
            for message_id, fields in entries:
                await self.consume_one(
                    DefaultStreamMessage(
                        type="stream",
                        channel=channel,
                        message_ids=[message_id],
                        data=fields,
                    ),
                )
class StreamBatchSubscriber(_StreamHandlerMixin):
    """Redis Stream subscriber that delivers each read's entries as one batch."""

    def __init__(
        self,
        config: "RedisSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Any]",
    ) -> None:
        parser = RedisBatchStreamParser(config)
        config.decoder = parser.decode_message
        config.parser = parser.parse_message
        super().__init__(config, specification, calls)

    async def _get_msgs(self, read: ReadCallable) -> None:
        """Run one ``read`` and consume each stream's entries as a single batch.

        Streams that come back with no entries are skipped. For example, the
        XAUTOCLAIM-based ``read`` returns ``((name, []),)`` when a scan step
        claimed nothing. Handlers therefore never receive an empty batch.
        """
        for stream_name, entries in await read(self.last_id):
            if not entries:
                continue

            # Advance the cursor past this batch before dispatching it.
            self.last_id = entries[-1][0].decode()

            ids: list[bytes] = [message_id for message_id, _ in entries]
            data: list[dict[bytes, bytes]] = [fields for _, fields in entries]

            msg = BatchStreamMessage(
                type="bstream",
                channel=stream_name.decode(),
                data=data,
                message_ids=ids,
            )

            await self.consume_one(msg)
class StreamConcurrentSubscriber(
    ConcurrentMixin["BrokerStreamMessage[Any]"],
    StreamSubscriber,
):
    """Stream subscriber that routes consumed messages through ``ConcurrentMixin``."""

    async def start(self) -> None:
        """Start the stream reader first, then launch the consume task."""
        await super().start()
        self.start_consume_task()

    async def consume_one(self, msg: "BrokerStreamMessage[Any]") -> None:
        """Hand ``msg`` to ``_put_msg`` instead of processing it inline."""
        await self._put_msg(msg)