Coverage for faststream / nats / subscriber / usecases / stream_pull_subscriber.py: 82%

76 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 01:48 +0000

1from collections.abc import AsyncIterator, Awaitable, Callable 

2from contextlib import suppress 

3from typing import TYPE_CHECKING, Any, Optional, cast 

4 

5import anyio 

6from nats.errors import ConnectionClosedError, TimeoutError 

7from nats.js.errors import ServiceUnavailableError 

8from typing_extensions import override 

9 

10from faststream._internal.endpoint.subscriber.mixins import ConcurrentMixin, TasksMixin 

11from faststream._internal.endpoint.utils import process_msg 

12from faststream.nats.parser import ( 

13 BatchParser, 

14) 

15 

16from .basic import DefaultSubscriber 

17from .stream_basic import StreamSubscriber 

18 

19if TYPE_CHECKING: 

20 from nats.aio.msg import Msg 

21 from nats.js import JetStreamContext 

22 

23 from faststream._internal.basic_types import SendableMessage 

24 from faststream._internal.endpoint.subscriber import SubscriberSpecification 

25 from faststream._internal.endpoint.subscriber.call_item import CallsCollection 

26 from faststream.nats.message import NatsMessage 

27 from faststream.nats.schemas import JStream, PullSub 

28 from faststream.nats.subscriber.config import NatsSubscriberConfig 

29 

30 

class PullStreamSubscriber(
    TasksMixin,
    StreamSubscriber,
):
    """Stream subscriber that fetches messages through a pull subscription.

    Fetched messages are dispatched one by one to the callback given to
    ``_consume_pull``.
    """

    subscription: Optional["JetStreamContext.PullSubscription"]

    def __init__(
        self,
        config: "NatsSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[Msg]",
        *,
        queue: str,
        pull_sub: "PullSub",
        stream: "JStream",
    ) -> None:
        """Initialise the parent subscriber and remember the pull options.

        Args:
            config: Subscriber configuration, forwarded to the parent class.
            specification: Subscriber specification, forwarded to the parent class.
            calls: Handler collection, forwarded to the parent class.
            queue: Forwarded to the parent class.
            pull_sub: Pull options; ``batch_size`` and ``timeout`` are read
                from it on every fetch.
            stream: Forwarded to the parent class.
        """
        super().__init__(config, specification, calls, queue=queue, stream=stream)

        self.pull_sub = pull_sub

    @override
    async def _create_subscription(self) -> None:
        """Open the pull subscription (only once) and schedule the pull loop."""
        if not self.subscription:
            self.subscription = await self.jetstream.pull_subscribe(
                subject=self.clear_subject,
                config=self.config,
                **self.extra_options,
            )
            # Every fetched message is handed to `self.consume`.
            self.add_task(self._consume_pull, func_kwargs={"cb": self.consume})

    async def _consume_pull(
        self,
        cb: Callable[["Msg"], Awaitable["SendableMessage"]],
    ) -> None:
        """Fetch batches for as long as the subscriber is running.

        Args:
            cb: Coroutine function called once per fetched message.
        """
        assert self.subscription

        while self.running:  # pragma: no branch
            # Start from an empty batch so a suppressed fetch error
            # simply results in nothing to dispatch this round.
            fetched = []
            with suppress(TimeoutError, ConnectionClosedError, ServiceUnavailableError):
                fetched = await self.subscription.fetch(
                    batch=self.pull_sub.batch_size,
                    timeout=self.pull_sub.timeout,
                )

            if not fetched:
                continue

            # The task group is left only after every callback has finished,
            # so the next fetch starts once the whole batch is processed.
            async with anyio.create_task_group() as task_group:
                for message in fetched:
                    task_group.start_soon(cb, message)

90 

91 

class ConcurrentPullStreamSubscriber(ConcurrentMixin["Msg"], PullStreamSubscriber):
    """Pull subscriber whose fetched messages are passed to ``_put_msg``."""

    @override
    async def _create_subscription(self) -> None:
        """Start the consume task, open the pull subscription and schedule the pull loop.

        Does nothing when a subscription already exists.
        """
        if not self.subscription:
            # The consume task is started before subscribing.
            # NOTE(review): presumably it drains what `_put_msg` enqueues —
            # confirm against ConcurrentMixin.
            self.start_consume_task()

            self.subscription = await self.jetstream.pull_subscribe(
                subject=self.clear_subject,
                config=self.config,
                **self.extra_options,
            )

            pull_loop_kwargs = {"cb": self._put_msg}
            self.add_task(self._consume_pull, func_kwargs=pull_loop_kwargs)

107 

108 

class BatchPullStreamSubscriber(
    TasksMixin,
    DefaultSubscriber[list["Msg"]],
):
    """Batch-message consumer class.

    Fetches messages through a pull subscription and passes every fetched
    batch to the handlers as a single list.
    """

    subscription: Optional["JetStreamContext.PullSubscription"]
    _fetch_sub: Optional["JetStreamContext.PullSubscription"]

    def __init__(
        self,
        config: "NatsSubscriberConfig",
        specification: "SubscriberSpecification[Any, Any]",
        calls: "CallsCollection[list[Msg]]",
        *,
        stream: "JStream",
        pull_sub: "PullSub",
    ) -> None:
        """Install the batch parser on ``config`` and initialise the parent.

        Args:
            config: Subscriber configuration. Its ``decoder`` and ``parser``
                attributes are overwritten with the batch variants before
                the parent constructor runs.
            specification: Subscriber specification, forwarded to the parent class.
            calls: Handler collection, forwarded to the parent class.
            stream: Stored on the instance as ``stream``.
            pull_sub: Pull options; ``batch_size`` and ``timeout`` are read
                from it on every fetch of the pull loop.
        """
        parser = BatchParser(pattern=config.subject)
        config.decoder = parser.decode_batch
        config.parser = parser.parse_batch
        super().__init__(config, specification, calls)

        self.stream = stream
        self.pull_sub = pull_sub

    async def _get_fetch_sub(self) -> "JetStreamContext.PullSubscription":
        """Return the subscription used for manual fetching, creating it on first use."""
        fetch_sub = self._fetch_sub
        if not fetch_sub:
            fetch_sub = self._fetch_sub = await self.jetstream.pull_subscribe(
                subject=self.clear_subject,
                config=self.config,
                **self.extra_options,
            )
        return fetch_sub

    @override
    async def get_one(
        self,
        *,
        timeout: float = 5,
    ) -> Optional["NatsMessage"]:
        """Fetch a single batch (``batch=1``) and run it through ``process_msg``.

        Asserts that the subscriber has no registered handlers.

        Args:
            timeout: Passed to the subscription's ``fetch`` call.

        Returns:
            The result of ``process_msg`` for the fetched data, or ``None``
            when the fetch raised ``TimeoutError``.
        """
        assert not self.calls, (
            "You can't use `get_one` method if subscriber has registered handlers."
        )

        fetch_sub = await self._get_fetch_sub()

        try:
            raw_message = await fetch_sub.fetch(
                batch=1,
                timeout=timeout,
            )
        except TimeoutError:
            return None

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        return cast(
            "NatsMessage",
            await process_msg(
                msg=raw_message,
                middlewares=(
                    m(raw_message, context=context) for m in self._broker_middlewares
                ),
                parser=async_parser,
                decoder=async_decoder,
            ),
        )

    @override
    async def __aiter__(self) -> AsyncIterator["NatsMessage"]:  # type: ignore[override]
        """Yield processed batches fetched one at a time (``batch=1``), without end.

        Asserts that the subscriber has no registered handlers.
        """
        assert not self.calls, (
            "You can't use iterator if subscriber has registered handlers."
        )

        fetch_sub = await self._get_fetch_sub()

        context = self._outer_config.fd_config.context
        async_parser, async_decoder = self._get_parser_and_decoder()

        while True:
            # NOTE(review): no timeout is passed and no error is handled here,
            # so any exception raised by `fetch` ends the iteration — confirm
            # this is the intended contract for idle subscriptions.
            raw_message = await fetch_sub.fetch(batch=1)

            yield cast(
                "NatsMessage",
                await process_msg(
                    msg=raw_message,
                    middlewares=(
                        m(raw_message, context=context) for m in self._broker_middlewares
                    ),
                    parser=async_parser,
                    decoder=async_decoder,
                ),
            )

    @override
    async def _create_subscription(self) -> None:
        """Create NATS subscription and start consume task.

        Does nothing when a subscription already exists.
        """
        if self.subscription:
            return

        self.subscription = await self.jetstream.pull_subscribe(
            subject=self.clear_subject,
            config=self.config,
            **self.extra_options,
        )
        self.add_task(self._consume_pull)

    async def _consume_pull(self) -> None:
        """Endless task consuming messages using NATS Pull subscriber.

        Each non-empty fetched batch is passed to ``consume`` as one list.
        """
        assert self.subscription, "You should call `create_subscription` at first."

        while self.running:  # pragma: no branch
            with suppress(TimeoutError, ConnectionClosedError, ServiceUnavailableError):
                messages = await self.subscription.fetch(
                    batch=self.pull_sub.batch_size,
                    timeout=self.pull_sub.timeout,
                )

                # Dispatching inside the guarded block means a suppressed
                # fetch error skips straight to the next iteration:
                # `messages` is never read while unbound, and a batch from
                # a previous iteration is never handed to `consume` twice.
                if messages:
                    await self.consume(messages)