Coverage for fastagency/adapters/nats/base.py: 17%

180 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-19 12:16 +0000

1## IONats part 

2 

3import asyncio 1abcdefghi

4import os 1abcdefghi

5import traceback 1abcdefghi

6from collections.abc import AsyncIterator, Iterator 1abcdefghi

7from contextlib import asynccontextmanager, contextmanager 1abcdefghi

8from queue import Queue 1abcdefghi

9from typing import TYPE_CHECKING, Any, Optional 1abcdefghi

10 

11from asyncer import asyncify, syncify 1abcdefghi

12from faststream import FastStream, Logger 1abcdefghi

13from faststream.nats import JStream, NatsBroker, NatsMessage 1abcdefghi

14from nats.aio.client import Client as NatsClient 1abcdefghi

15from nats.errors import NoServersError 1abcdefghi

16from nats.js import JetStreamContext, api 1abcdefghi

17from nats.js.errors import KeyNotFoundError, NoKeysError 1abcdefghi

18from nats.js.kv import KeyValue 1abcdefghi

19 

20from ...base import UI, CreateWorkflowUIMixin, ProviderProtocol, Runnable, UIBase 1abcdefghi

21from ...exceptions import FastAgencyNATSConnectionError, FastAgencyNATSKeyError 1abcdefghi

22from ...logging import get_logger 1abcdefghi

23from ...messages import ( 1abcdefghi

24 AskingMessage, 

25 IOMessage, 

26 InitiateWorkflowModel, 

27 InputResponseModel, 

28 MessageProcessorMixin, 

29 MultipleChoice, 

30 TextInput, 

31 TextMessage, 

32) 

33 

34if TYPE_CHECKING: 1abcdefghi

35 from faststream.nats.subscriber.asyncapi import AsyncAPISubscriber 

36 

37 

logger = get_logger(__name__)

# The single JetStream stream backing every subject used by the NATS adapter
# and provider in this module.
JETSTREAM = JStream(
    subjects=[
        # Kick-off requests for a new workflow; any worker in the queue group
        # may pick one up.
        "chat.server.initiate_chat",
        # Server -> client traffic (output and questions), per conversation:
        # chat.client.messages.<user_uuid>.<workflow_uuid>. The subject is built
        # per workflow so the client side can pin itself to one worker.
        "chat.client.messages.*.*",
        # Client -> server traffic (answers), per conversation:
        # chat.server.messages.<user_uuid>.<workflow_uuid>. The worker that
        # accepted the workflow subscribes to it dynamically, which pins it.
        "chat.server.messages.*.*",
        # Subject reserved for workflow discovery.
        "discovery",
    ],
    name="FastAgency",
)

55 

56 

class NatsAdapter(MessageProcessorMixin, CreateWorkflowUIMixin):
    """Server-side adapter that runs workflows requested over NATS.

    A workflow is started when a message arrives on ``chat.server.initiate_chat``.
    Output and questions for that workflow are published to
    ``chat.client.messages.<user_id>.<workflow_uuid>``; answers are consumed from
    ``chat.server.messages.<user_id>.<workflow_uuid>``.
    """

    def __init__(
        self,
        provider: ProviderProtocol,
        *,
        nats_url: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        super_conversation: Optional["NatsAdapter"] = None,
    ) -> None:
        """Provider for NATS.

        Args:
            provider (ProviderProtocol): The provider.
            nats_url (Optional[str], optional): The NATS URL. Defaults to None in which case 'nats://localhost:4222' is used.
            user (Optional[str], optional): The user. Defaults to None.
            password (Optional[str], optional): The password. Defaults to None.
            super_conversation (Optional["NatsAdapter"], optional): The super conversation. Defaults to None.
        """
        self.provider = provider
        self.nats_url = nats_url or "nats://localhost:4222"
        self.user = user
        self.password = password
        # Client replies (NatsMessage objects) pushed by _handle_input and
        # drained by _wait_for_question_response.
        self.queue: Queue = Queue()  # type: ignore[type-arg]

        self.broker = NatsBroker(self.nats_url, user=user, password=password)
        self.app = FastStream(self.broker)
        self.subscriber: "AsyncAPISubscriber"
        # Assigned by the initiate handler for each workflow it starts.
        # NOTE(review): these are per-instance, so a second workflow started on
        # the same adapter overwrites the subjects of the first — confirm that
        # one adapter only ever serves one workflow at a time.
        self._input_request_subject: str
        self._input_receive_subject: str

        # Strong references to in-flight workflow / cleanup tasks. The event loop
        # only keeps weak references to tasks, so an unreferenced task can be
        # garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task[Any]] = set()

        self.super_conversation: Optional[NatsAdapter] = super_conversation
        self.sub_conversations: list[NatsAdapter] = []

        self._create_initiate_subscriber()

    async def _handle_input(
        self, body: InputResponseModel, msg: NatsMessage, logger: Logger
    ) -> None:
        """Handle input from the client by consuming messages from chat.server.messages.*.*.

        Args:
            body (InputResponseModel): The body of the message.
            msg (NatsMessage): The message object.
            logger (Logger): The logger object (gets injected)
        """
        logger.info(
            f"Received message in subject '{self._input_receive_subject}': {body}"
        )
        await msg.ack()
        self.queue.put(msg)

    async def _send_error_msg(self, e: Exception, logger: Logger) -> None:
        """Send an error message.

        Args:
            e (Exception): The exception.
            logger (Logger): The logger object (gets injected)
        """
        logger.error(f"Error in chat: {e}")
        logger.error(traceback.format_exc())

        error_msg = InputResponseModel(msg=str(e), error=True, question_uuid=None)
        await self.broker.publish(error_msg, self._input_request_subject)

    def _create_initiate_subscriber(self) -> None:
        """Register the ``chat.server.initiate_chat`` handler on ``self.broker``.

        The subscriber joins the ``initiate_workers`` queue group on the
        ``JETSTREAM`` stream.
        """

        @self.broker.subscriber(
            "chat.server.initiate_chat",
            stream=JETSTREAM,
            queue="initiate_workers",
            deliver_policy=api.DeliverPolicy("all"),
        )
        async def initiate_handler(
            body: InitiateWorkflowModel, msg: NatsMessage, logger: Logger
        ) -> None:
            """Initiate the handler.

            1. Subscribes to the chat.server.initiate_chat topic.
            2. When a message is consumed from the topic, it dynamically subscribes to the chat.server.messages.<user_uuid>.<workflow_uuid> topic.
            3. Starts the chat workflow after successfully subscribing to the chat.server.messages.<user_uuid>.<workflow_uuid> topic.

            Args:
                body (InitiateModel): The body of the message.
                msg (NatsMessage): The message object.
                logger (Logger): The logger object (gets injected)

            """
            await msg.ack()

            logger.info(
                f"Message in subject 'chat.server.initiate_chat': {body=} -> from process id {os.getpid()}"
            )
            user_id = body.user_id if body.user_id else "None"
            workflow_uuid = body.workflow_uuid.hex
            self._input_request_subject = (
                f"chat.client.messages.{user_id}.{workflow_uuid}"
            )
            self._input_receive_subject = (
                f"chat.server.messages.{user_id}.{workflow_uuid}"
            )

            # dynamically subscribe to the chat server
            subscriber = self.broker.subscriber(
                subject=self._input_receive_subject,
                stream=JETSTREAM,
                deliver_policy=api.DeliverPolicy("all"),
            )
            subscriber(self._handle_input)
            self.broker.setup_subscriber(subscriber)
            await subscriber.start()

            try:

                async def start_chat(
                    ui_base: UIBase,
                    provider: ProviderProtocol,
                    name: str,
                    params: dict[str, Any],
                    workflow_uuid: str,
                ) -> None:  # type: ignore [return]
                    def _start_chat(
                        ui_base: UIBase,
                        provider: ProviderProtocol,
                        name: str,
                        params: dict[str, Any],
                        workflow_uuid: str,
                    ) -> None:  # type: ignore [return]
                        ui: UI = ui_base.create_workflow_ui(workflow_uuid=workflow_uuid)
                        try:
                            provider.run(
                                name=name,
                                ui=ui,
                                **params,
                            )
                        except Exception as e:
                            # exc_info logs the traceback of the failure itself;
                            # stack_info would only log where logger.error was called.
                            logger.error(
                                f"Unexpected error in NatsAdapter.start_chat: {e}",
                                exc_info=True,
                            )
                            ui.error(
                                sender="NatsAdapter",
                                short=f"Unexpected error: {e}",
                                long=traceback.format_exc(),
                            )
                            return

                    # provider.run() is synchronous; run it in a worker thread so
                    # the event loop keeps serving NATS messages meanwhile.
                    return await asyncify(_start_chat)(
                        ui_base, provider, name, params, workflow_uuid
                    )

                task = asyncio.create_task(
                    start_chat(
                        self, self.provider, body.name, body.params, workflow_uuid
                    )
                )  # type: ignore
                self._background_tasks.add(task)

                async def callback(t: asyncio.Task[Any]) -> None:
                    # Once the workflow finishes, forget its task and stop
                    # listening on its per-workflow reply subject.
                    try:
                        self._background_tasks.discard(t)
                        await subscriber.close()
                    except Exception as e:
                        logger.error(f"Error in callback: {e}")
                        logger.error(traceback.format_exc())

                def on_done(t: asyncio.Task[Any]) -> None:
                    # Done-callbacks are plain functions, so the async cleanup
                    # runs as its own task; keep a reference to it until it ends.
                    cleanup = asyncio.create_task(callback(t))
                    self._background_tasks.add(cleanup)
                    cleanup.add_done_callback(self._background_tasks.discard)

                task.add_done_callback(on_done)

            except Exception as e:
                await self._send_error_msg(e, logger)

    async def _publish_discovery(self) -> None:
        """Publish the discovery message."""
        jetstream_key_value = await self.broker.key_value(bucket="discovery")

        names = self.provider.names
        for name in names:
            description = self.provider.get_description(name)
            await jetstream_key_value.put(name, description.encode())

    # todo: make it a router
    @asynccontextmanager
    async def lifespan(self, app: Any) -> AsyncIterator[None]:
        """Start the broker and publish discovery info for the app's lifetime."""
        async with self.broker:
            await self.broker.start()
            await self._publish_discovery()
            try:
                yield
            finally:
                await self.broker.close()

    def visit_default(self, message: IOMessage) -> None:
        """Publish ``message`` to the client subject of the current workflow."""
        content = message.model_dump()
        logger.debug(f"visit_default(): {content=}")
        syncify(self.broker.publish)(content, self._input_request_subject)

    def visit_text_message(self, message: TextMessage) -> None:
        """Publish a text message to the client subject of the current workflow."""
        content = message.model_dump()
        logger.debug(f"visit_text_message(): {content=}")
        syncify(self.broker.publish)(content, self._input_request_subject)

    async def _wait_for_question_response_with_timeout(
        self, question_id: str, *, timeout: int = 180
    ) -> InputResponseModel:
        """Wait for the question response.

        Args:
            question_id (str): The question ID.
            timeout (int, optional): The timeout in seconds. Defaults to 180.

        Returns:
            InputResponseModel: The client's reply, or an ``error=True`` response
            telling the workflow to exit if no reply arrived within ``timeout``.
        """
        try:
            # Bound the wait by `timeout` seconds.
            return await asyncio.wait_for(
                self._wait_for_question_response(question_id), timeout=timeout
            )
        except asyncio.TimeoutError:
            logger.debug(
                f"Timeout: User did not send a reply within {timeout} seconds."
            )
            return InputResponseModel(
                msg="User didn't send a reply. Exit the workflow execution.",
                question_uuid=question_id,
                error=True,
            )

    async def _wait_for_question_response(self, question_id: str) -> InputResponseModel:
        """Poll ``self.queue`` until the reply for ``question_id`` arrives.

        Replies addressed to other questions are put back on the queue. There is
        no timeout here; ``_wait_for_question_response_with_timeout`` bounds it.

        Args:
            question_id (str): Hex form of the question UUID to wait for.

        Returns:
            InputResponseModel: The matching reply.
        """
        while True:
            while self.queue.empty():  # noqa: ASYNC110
                await asyncio.sleep(0.1)

            msg: NatsMessage = self.queue.get()
            self.queue.task_done()
            input_response = InputResponseModel.model_validate_json(
                msg.raw_message.data.decode("utf-8")
            )

            question_id_hex = (
                input_response.question_uuid.hex
                if input_response.question_uuid
                else "None"
            )
            logger.debug(
                f"{question_id_hex=} {question_id=} match={question_id_hex == question_id}"
            )
            if question_id_hex == question_id:
                logger.debug("Got the response")
                return input_response

            # Not our reply: requeue it and yield to the event loop. Without this
            # await the loop would spin on the same message forever, starving
            # _handle_input and preventing the caller's wait_for() timeout from
            # ever firing.
            self.queue.put(msg)
            await asyncio.sleep(0.1)

    def visit_text_input(self, message: TextInput) -> str:
        """Publish a text question and block until the client answers (or times out)."""
        content = message.model_dump()
        question_id = message.uuid
        logger.info(f"visit_text_input(): {content=}")
        syncify(self.broker.publish)(content, self._input_request_subject)
        logger.info(
            f"visit_text_input(): published message '{content}' to {self._input_request_subject}"
        )

        input_response: InputResponseModel = syncify(
            self._wait_for_question_response_with_timeout
        )(question_id=question_id)
        logger.info(input_response)
        return input_response.msg

    def visit_multiple_choice(self, message: MultipleChoice) -> str:
        """Publish a multiple-choice question and block until the client answers (or times out)."""
        content = message.model_dump()
        question_id = message.uuid
        logger.info(f"visit_multiple_choice(): {content=}")
        syncify(self.broker.publish)(content, self._input_request_subject)

        input_response: InputResponseModel = syncify(
            self._wait_for_question_response_with_timeout
        )(question_id=question_id)
        logger.info(input_response)
        return input_response.msg

    def process_message(self, message: IOMessage) -> Optional[str]:
        """Dispatch ``message`` to its visitor, never letting an error escape.

        Returns:
            Optional[str]: The visitor's result. If the visitor raises, this is a
            generic error string for asking messages and ``None`` otherwise.
        """
        try:
            return self.visit(message)
        except Exception as e:
            # exc_info records the traceback of the exception being handled.
            logger.error(f"Error in process_message: {e}", exc_info=True)
            # do not reraise, we must go on
            if isinstance(message, AskingMessage):
                return "Error: Something went wrong. Please check logs for details."
            return None

    def create_subconversation(self) -> "NatsAdapter":
        """Return ``self``; sub-conversations share this adapter."""
        return self

    @classmethod
    def create_provider(
        cls,
        nats_url: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
    ) -> ProviderProtocol:
        """Create a ``NatsProvider`` (client side) for the same NATS server."""
        return NatsProvider(nats_url=nats_url, user=user, password=password)

    @contextmanager
    def create(self, app: Runnable, import_string: str) -> Iterator[None]:
        raise NotImplementedError("NatsAdapter.create() is not implemented")

    def start(
        self,
        *,
        app: Runnable,
        import_string: str,
        name: Optional[str] = None,
        params: dict[str, Any],
        single_run: bool = False,
    ) -> None:
        raise NotImplementedError("NatsAdapter.start() is not implemented")

373 raise NotImplementedError("NatsAdapter.start() is not implemented") 

374 

375 

class NatsProvider(ProviderProtocol):
    """Client-side provider that runs workflows hosted behind a ``NatsAdapter``."""

    def __init__(
        self,
        nats_url: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
    ) -> None:
        """Initialize the nats workflows.

        Args:
            nats_url (Optional[str], optional): The NATS URL. Defaults to None in which case 'nats://localhost:4222' is used.
            user (Optional[str], optional): The user. Defaults to None.
            password (Optional[str], optional): The password. Defaults to None.
        """
        self.nats_url = nats_url or "nats://localhost:4222"
        self.user = user
        self.password = password

        self.broker = NatsBroker(self.nats_url, user=self.user, password=self.password)
        self.app = FastStream(self.broker)

        self._initiate_chat_subject: str = "chat.server.initiate_chat"

        # True while run() owns an active broker lifespan.
        self.is_broker_running: bool = False

    async def _setup_subscriber(
        self, ui: UI, from_server_subject: str, to_server_subject: str
    ) -> None:
        """Forward server messages to ``ui`` and send answers back to the server.

        Args:
            ui (UI): Local UI that renders messages and collects answers.
            from_server_subject (str): Subject the server publishes to.
            to_server_subject (str): Subject answers to questions are published on.
        """
        logger.info(
            f"Setting up subscriber for {from_server_subject=}, {to_server_subject=}"
        )

        async def consume_msg_from_nats(msg: dict[str, Any], logger: Logger) -> None:
            logger.debug(f"Received message from topic {from_server_subject}: {msg}")
            # Server-side failures arrive as {"error": True, "msg": ...}.
            iomessage = (
                IOMessage.create(**{"type": "error", "long": msg["msg"]})
                if msg.get("error")
                else IOMessage.create(**msg)
            )
            if isinstance(iomessage, AskingMessage):
                processed_message = ui.process_message(iomessage)
                response = InputResponseModel(
                    msg=processed_message, question_uuid=iomessage.uuid
                )
                logger.debug(f"Processed response: {response}")
                await self.broker.publish(response, to_server_subject)
            else:
                ui.process_message(iomessage)

        subscriber = self.broker.subscriber(
            from_server_subject,
            stream=JETSTREAM,
            deliver_policy=api.DeliverPolicy("all"),
        )
        subscriber(consume_msg_from_nats)
        self.broker.setup_subscriber(subscriber)
        await subscriber.start()
        logger.info(f"Subscriber for {from_server_subject} started")

    def run(
        self,
        name: str,
        ui: UI,
        user_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Ask a NATS worker to run workflow ``name`` and relay its I/O through ``ui``.

        Args:
            name (str): Workflow name.
            ui (UI): UI used to display messages and answer questions.
            user_id (Optional[str], optional): User id used in the subjects. Defaults to None.
            **kwargs (Any): Workflow parameters forwarded to the server.

        Returns:
            str: A completion marker.
            NOTE(review): ``_setup_and_run`` polls forever, so in practice this
            only returns via an exception or cancellation.
        """
        # subscribe to whatever topic you need
        # consume a message from the topic and call that visitor pattern (which is happening in NatsProvider)
        workflow_uuid = ui._workflow_uuid
        init_message = InitiateWorkflowModel(
            user_id=user_id,
            workflow_uuid=workflow_uuid,
            params=kwargs,
            name=name,
        )
        # The server (NatsAdapter's initiate handler) substitutes "None" for any
        # falsy user_id when building the subjects; do the same so both sides
        # agree even for user_id == "" (which would otherwise yield an empty
        # subject token).
        subject_user_id = user_id if user_id else "None"
        _from_server_subject = f"chat.client.messages.{subject_user_id}.{workflow_uuid}"
        _to_server_subject = f"chat.server.messages.{subject_user_id}.{workflow_uuid}"

        async def send_initiate_chat_msg() -> None:
            await self.broker.publish(init_message, self._initiate_chat_subject)
            logger.info("Initiate chat message sent")

        @asynccontextmanager
        async def lifespan() -> AsyncIterator[None]:
            async with self.broker:
                await self.broker.start()
                logger.debug("Broker started")
                try:
                    yield
                finally:
                    await self.broker.close()

        async def _setup_and_run() -> None:
            await send_initiate_chat_msg()
            await self._setup_subscriber(ui, _from_server_subject, _to_server_subject)
            while True:  # noqa: ASYNC110
                await asyncio.sleep(0.1)

        async def run_lifespan() -> None:
            if self.is_broker_running:
                await _setup_and_run()
                return

            self.is_broker_running = True
            try:
                async with lifespan():
                    await _setup_and_run()
            finally:
                # The broker is closed once lifespan() exits (normally, by
                # exception or by cancellation); clear the flag so a later run()
                # starts it again instead of reusing a closed broker.
                self.is_broker_running = False

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        loop.run_until_complete(run_lifespan())

        return "NatsWorkflows.run() completed"

    @asynccontextmanager
    async def _get_jetstream_context(self) -> AsyncIterator[JetStreamContext]:
        """Yield a JetStream context on a dedicated connection, closed on exit."""
        nc = NatsClient()
        await nc.connect(self.nats_url, user=self.user, password=self.password)
        js = nc.jetstream()
        try:
            yield js
        finally:
            await nc.close()

    @asynccontextmanager
    async def _get_jetstream_key_value(
        self, bucket: str = "discovery"
    ) -> AsyncIterator[KeyValue]:
        """Yield the key-value bucket ``bucket`` via ``js.create_key_value``."""
        async with self._get_jetstream_context() as js:
            kv = await js.create_key_value(bucket=bucket)
            yield kv

    async def _get_names(self) -> list[str]:
        """Return workflow names stored in the discovery bucket ([] if empty).

        Raises:
            FastAgencyNATSConnectionError: If the NATS server is unreachable.
        """
        try:
            async with self._get_jetstream_key_value() as kv:
                names = await kv.keys()
        except NoKeysError:
            names = []
        except NoServersError as e:
            raise FastAgencyNATSConnectionError(
                f"Unable to connect to NATS server at {self.nats_url}"
            ) from e

        return names

    async def _get_description(self, name: str) -> str:
        """Return the stored description for workflow ``name`` ("" if empty).

        Raises:
            FastAgencyNATSKeyError: If ``name`` is not in the discovery bucket.
            FastAgencyNATSConnectionError: If the NATS server is unreachable.
        """
        try:
            async with self._get_jetstream_key_value() as kv:
                description = await kv.get(name)
                return description.value.decode() if description.value else ""
        except KeyNotFoundError as e:
            raise FastAgencyNATSKeyError(
                f"Workflow name {name} not found to get description"
            ) from e
        except NoServersError as e:
            raise FastAgencyNATSConnectionError(
                f"Unable to connect to NATS server at {self.nats_url}"
            ) from e

    @property
    def names(self) -> list[str]:
        """Workflow names published by the servers (uses ``asyncio.run``)."""
        names = asyncio.run(self._get_names())
        logger.debug(f"Names: {names}")
        return names

    def get_description(self, name: str) -> str:
        """Return the description of workflow ``name`` (uses ``asyncio.run``)."""
        description = asyncio.run(self._get_description(name))
        logger.debug(f"Description: {description}")
        return description