Coverage for faststream / redis / subscriber / usecases / stream_subscriber.py: 86%

128 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1import asyncio 

2import logging 

3import math 

4from collections.abc import AsyncIterator, Awaitable, Callable 

5from typing import TYPE_CHECKING, Any, Optional, TypeAlias 

6 

7from redis.exceptions import ResponseError 

8from typing_extensions import override 

9 

10from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin 

11from faststream._internal.endpoint.utils import process_msg 

12from faststream.redis.exceptions import StreamGroupNotFoundError 

13from faststream.redis.message import ( 

14 BatchStreamMessage, 

15 DefaultStreamMessage, 

16 RedisStreamMessage, 

17) 

18from faststream.redis.parser import ( 

19 RedisBatchStreamParser, 

20 RedisStreamParser, 

21) 

22 

23from .basic import LogicSubscriber 

24 

25if TYPE_CHECKING: 

26 from anyio import Event 

27 

28 from faststream._internal.endpoint.subscriber import SubscriberSpecification 

29 from faststream._internal.endpoint.subscriber.call_item import ( 

30 CallsCollection, 

31 ) 

32 from faststream.message import StreamMessage as BrokerStreamMessage 

33 from faststream.redis.schemas import StreamSub 

34 from faststream.redis.subscriber.config import RedisSubscriberConfig 

35 

36 

37TopicName: TypeAlias = bytes 

38Offset: TypeAlias = bytes 

39 

40ReadResponse = tuple[ 

41 tuple[ 

42 TopicName, 

43 tuple[ 

44 tuple[ 

45 Offset, 

46 dict[bytes, bytes], 

47 ], 

48 ..., 

49 ], 

50 ], 

51 ..., 

52] 

53ReadCallable = Callable[[str], Awaitable[ReadResponse]] 

54 

55 

class _StreamHandlerMixin(LogicSubscriber):
    """Shared Redis Streams logic for the stream subscriber variants.

    The read command depends on the configured ``StreamSub``:

    * no consumer group -> ``XREAD`` starting from ``last_id``;
    * ``group`` + ``consumer`` -> ``XREADGROUP``;
    * ``group`` + ``consumer`` + ``min_idle_time`` -> ``XAUTOCLAIM`` of idle
      pending entries.

    Subclasses implement ``_get_msgs(read)`` to turn one read reply into
    broker messages.
    """

    # Seconds to wait after an unexpected fetch error before polling again.
    _fetch_error_delay: float = 1.0

    def __init__(
        self,
        config: "RedisSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Any]",
    ) -> None:
        super().__init__(config, specification, calls)

        assert config.stream_sub
        self._stream_sub = config.stream_sub
        # Cursor passed to ``read``; advanced to the id of the last consumed entry.
        self.last_id = config.stream_sub.last_id
        self.min_idle_time = config.stream_sub.min_idle_time
        # Where the next XAUTOCLAIM scan of the pending list starts.
        self.autoclaim_start_id = b"0-0"

    @property
    def stream_sub(self) -> "StreamSub":
        """The configured stream with the broker prefix applied."""
        return self._stream_sub.add_prefix(self._outer_config.prefix)

    def get_log_context(
        self,
        message: Optional["BrokerStreamMessage[Any]"],
    ) -> dict[str, str]:
        """Build the logging context for ``message`` on this subscriber's stream."""
        return self.build_log_context(
            message=message,
            channel=self.stream_sub.name,
        )

    @override
    async def _consume(self, *args: Any, start_signal: "Event") -> None:
        """Run the fetch loop while the subscriber is running.

        ``args`` are forwarded to ``_get_msgs`` (``start`` passes the ``read``
        callable). ``start_signal`` is set once the client answers a ping, and
        in any case after the first fetch attempt.

        Raises:
            StreamGroupNotFoundError: Redis reported ``NOGROUP`` for the group.
            ResponseError: any other Redis response error.
        """
        if await self._client.ping():
            start_signal.set()

        while self.running:
            try:
                await self._get_msgs(*args)

            except ResponseError as e:  # noqa: PERF203
                if "NOGROUP" in str(e):
                    msg = (
                        f"Consumer group `{self.stream_sub.group}` for stream "
                        f"`{self.stream_sub.name}` no longer exists. "
                        "The stream was likely deleted or flushed. "
                        "Stopping subscriber — restart the application to recreate the group."
                    )
                    raise StreamGroupNotFoundError(msg) from e
                raise

            except Exception as e:
                self._log(
                    log_level=logging.ERROR,
                    message="Message fetch error",
                    exc_info=e,
                )
                # Back off before retrying so a persistent failure (e.g. a lost
                # connection) does not turn this loop into a busy spin that
                # floods the log.
                await asyncio.sleep(self._fetch_error_delay)

            finally:
                if not start_signal.is_set():
                    start_signal.set()

    @override
    async def start(self) -> None:
        """Prepare the consumer group (if any) and start consuming.

        Picks the ``read`` strategy handed to the base ``start``: XREADGROUP,
        XAUTOCLAIM (when ``min_idle_time`` is set) or plain XREAD.
        """
        client = self._client

        self.extra_watcher_options.update(
            redis=client,
            group=self.stream_sub.group,
        )

        stream = self.stream_sub

        read: ReadCallable

        if stream.group and stream.consumer:
            # ">" is only meaningful for XREADGROUP; for group creation it maps
            # to "$" (entries added from now on).
            group_create_id = "$" if self.last_id == ">" else self.last_id
            try:
                await client.xgroup_create(
                    name=stream.name,
                    id=group_create_id,
                    groupname=stream.group,
                    mkstream=True,
                )
            except ResponseError as e:
                # An existing group (e.g. from a previous run) is reused.
                if "already exists" not in str(e):
                    raise

            if stream.min_idle_time is None:

                def read(
                    _: str,
                ) -> Awaitable[ReadResponse]:
                    # The cursor argument is ignored: the configured id is
                    # always used for group reads.
                    return client.xreadgroup(
                        groupname=stream.group,
                        consumername=stream.consumer,
                        streams={stream.name: stream.last_id},
                        count=stream.max_records,
                        block=stream.polling_interval,
                        noack=stream.no_ack,
                    )

            else:

                async def read(_: str) -> ReadResponse:
                    next_id, messages, *_ = await client.xautoclaim(
                        name=stream.name,
                        groupname=stream.group,
                        consumername=stream.consumer,
                        min_idle_time=self.min_idle_time,
                        start_id=self.autoclaim_start_id,
                        count=1,
                    )

                    # Resume the pending-list scan where this call stopped.
                    self.autoclaim_start_id = next_id

                    if next_id == b"0-0" and not messages:
                        # Scan wrapped around with nothing to claim: XAUTOCLAIM
                        # does not block, so wait before polling again.
                        await asyncio.sleep(stream.polling_interval / 1000)  # ms to s
                        return ()

                    return ((stream.name.encode(), messages),)

        else:

            def read(
                last_id: str,
            ) -> Awaitable[ReadResponse]:
                return client.xread(
                    {stream.name: last_id},
                    block=stream.polling_interval,
                    count=stream.max_records,
                )

        await super().start(read)

    async def _read_single_entry(
        self,
        timeout: float,
    ) -> tuple[bytes, bytes, dict[bytes, bytes]] | None:
        """Fetch at most one stream entry for ``get_one`` / ``__aiter__``.

        Args:
            timeout: Seconds XREAD / XREADGROUP may block (XAUTOCLAIM never blocks).

        Returns:
            ``(stream_name, message_id, fields)``, or ``None`` if nothing was read.
        """
        stream = self.stream_sub

        if stream.group and stream.consumer:
            if self.min_idle_time is None:
                response = await self._client.xreadgroup(
                    groupname=stream.group,
                    consumername=stream.consumer,
                    # Use the configured id (as the subscription loop in
                    # ``start`` does), not ``self.last_id``: that one is
                    # advanced to concrete entry ids, and XREADGROUP with a
                    # concrete id returns this consumer's pending entries
                    # instead of new ones.
                    streams={stream.name: stream.last_id},
                    block=math.ceil(timeout * 1000),
                    count=1,
                )
            else:
                next_id, messages, *_ = await self._client.xautoclaim(
                    name=stream.name,
                    groupname=stream.group,
                    consumername=stream.consumer,
                    min_idle_time=self.min_idle_time,
                    start_id=self.autoclaim_start_id,
                    count=1,
                )
                # Resume the pending-list scan where this call stopped.
                self.autoclaim_start_id = next_id
                if not messages:
                    return None
                message_id, raw_message = messages[0]
                return stream.name.encode(), message_id, raw_message
        else:
            response = await self._client.xread(
                {stream.name: self.last_id},
                block=math.ceil(timeout * 1000),
                count=1,
            )

        # A stream may come back with no entries, so check the entries rather
        # than only the reply itself.
        for stream_name, entries in response or ():
            if entries:
                message_id, raw_message = entries[0]
                return stream_name, message_id, raw_message
        return None

    async def _parse_entry(
        self,
        stream_name: bytes,
        message_id: bytes,
        raw_message: dict[bytes, bytes],
        *,
        context: Any,
        parser: Any,
        decoder: Any,
    ) -> "RedisStreamMessage":
        """Wrap one raw entry and pass it to ``process_msg`` with the broker middlewares."""
        redis_incoming_msg = DefaultStreamMessage(
            type="stream",
            channel=stream_name.decode(),
            message_ids=[message_id],
            data=raw_message,
        )

        msg: RedisStreamMessage = await process_msg(  # type: ignore[assignment]
            msg=redis_incoming_msg,
            middlewares=(
                m(redis_incoming_msg, context=context) for m in self._broker_middlewares
            ),
            parser=parser,
            decoder=decoder,
        )
        return msg

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5.0,
    ) -> "RedisStreamMessage | None":
        """Read and parse a single entry.

        Args:
            timeout: Seconds to block waiting for an entry (XREAD / XREADGROUP
                only; the XAUTOCLAIM path returns immediately).

        Returns:
            The parsed message, or ``None`` when no entry was available.
        """
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        entry = await self._read_single_entry(timeout)
        if entry is None:
            return None

        stream_name, message_id, raw_message = entry
        self.last_id = message_id.decode()

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        return await self._parse_entry(
            stream_name,
            message_id,
            raw_message,
            context=context,
            parser=async_parser,
            decoder=async_decoder,
        )

    @override
    async def __aiter__(self) -> AsyncIterator["RedisStreamMessage"]:  # type: ignore[override]
        """Yield parsed entries one at a time, indefinitely."""
        assert not self.calls, (
            "You can't use iterator if subscriber has registered handlers."
        )

        timeout = 5

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        while True:
            entry = await self._read_single_entry(timeout)

            if entry is None:
                stream = self.stream_sub
                uses_autoclaim = (
                    bool(stream.group and stream.consumer)
                    and self.min_idle_time is not None
                )
                if uses_autoclaim and self.autoclaim_start_id == b"0-0":
                    # XAUTOCLAIM never blocks: once the scan has wrapped around
                    # with nothing to claim, wait (as the subscription loop
                    # does) instead of re-polling Redis in a tight loop.
                    await asyncio.sleep(stream.polling_interval / 1000)  # ms to s
                continue

            stream_name, message_id, raw_message = entry
            self.last_id = message_id.decode()

            yield await self._parse_entry(
                stream_name,
                message_id,
                raw_message,
                context=context,
                parser=async_parser,
                decoder=async_decoder,
            )

333 

334 

class StreamSubscriber(_StreamHandlerMixin):
    """Stream subscriber that dispatches every entry as its own message."""

    def __init__(
        self,
        config: "RedisSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Any]",
    ) -> None:
        # Install the single-entry parser/decoder on the config before handing
        # it to the base initializer.
        stream_parser = RedisStreamParser(config)
        config.decoder = stream_parser.decode_message
        config.parser = stream_parser.parse_message
        super().__init__(config, specification, calls)

    async def _get_msgs(self, read: ReadCallable) -> None:
        """Perform one read and pass each returned entry to ``consume_one``.

        Args:
            read: Strategy built in ``start``; called with the current cursor.
        """
        reply = await read(self.last_id)

        for topic, entries in reply:
            if not entries:
                continue

            # Advance the cursor to the newest entry of this reply.
            self.last_id = entries[-1][0].decode()

            for entry_id, fields in entries:
                await self.consume_one(
                    DefaultStreamMessage(
                        type="stream",
                        channel=topic.decode(),
                        message_ids=[entry_id],
                        data=fields,
                    ),
                )

381 

382 

class StreamBatchSubscriber(_StreamHandlerMixin):
    """Stream subscriber that dispatches all entries of one read as a single batch."""

    def __init__(
        self,
        config: "RedisSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Any]",
    ) -> None:
        # Install the batch parser/decoder on the config before handing it to
        # the base initializer.
        parser = RedisBatchStreamParser(config)
        config.decoder = parser.decode_message
        config.parser = parser.parse_message
        super().__init__(config, specification, calls)

    async def _get_msgs(self, read: ReadCallable) -> None:
        """Perform one read and dispatch each non-empty stream reply as a batch.

        Args:
            read: Strategy built in ``start``; called with the current cursor.
        """
        for stream_name, msgs in await read(self.last_id):
            # A stream can come back with no entries (e.g. an XAUTOCLAIM scan
            # that advanced without claiming anything); never hand an empty
            # batch to the handler.
            if not msgs:
                continue

            # Advance the cursor to the newest entry of this reply.
            self.last_id = msgs[-1][0].decode()

            msg = BatchStreamMessage(
                type="bstream",
                channel=stream_name.decode(),
                data=[fields for _, fields in msgs],
                message_ids=[message_id for message_id, _ in msgs],
            )

            await self.consume_one(msg)

422 

423 

class StreamConcurrentSubscriber(
    ConcurrentMixin["BrokerStreamMessage[Any]"],
    StreamSubscriber,
):
    """``StreamSubscriber`` that hands messages to ``ConcurrentMixin`` instead of
    processing them inline.

    NOTE(review): presumably ``_put_msg`` enqueues for the task launched by
    ``start_consume_task`` — confirm against ``ConcurrentMixin``.
    """

    @override
    async def start(self) -> None:
        """Start the underlying stream subscriber, then the concurrent consume task."""
        await super().start()
        self.start_consume_task()

    @override
    async def consume_one(self, msg: "BrokerStreamMessage[Any]") -> None:
        """Forward ``msg`` to ``_put_msg`` rather than consuming it directly."""
        await self._put_msg(msg)