Coverage for fastagency/adapters/nats/base.py: 17%
180 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-19 12:16 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-19 12:16 +0000
1## IONats part
3import asyncio 1abcdefghi
4import os 1abcdefghi
5import traceback 1abcdefghi
6from collections.abc import AsyncIterator, Iterator 1abcdefghi
7from contextlib import asynccontextmanager, contextmanager 1abcdefghi
8from queue import Queue 1abcdefghi
9from typing import TYPE_CHECKING, Any, Optional 1abcdefghi
11from asyncer import asyncify, syncify 1abcdefghi
12from faststream import FastStream, Logger 1abcdefghi
13from faststream.nats import JStream, NatsBroker, NatsMessage 1abcdefghi
14from nats.aio.client import Client as NatsClient 1abcdefghi
15from nats.errors import NoServersError 1abcdefghi
16from nats.js import JetStreamContext, api 1abcdefghi
17from nats.js.errors import KeyNotFoundError, NoKeysError 1abcdefghi
18from nats.js.kv import KeyValue 1abcdefghi
20from ...base import UI, CreateWorkflowUIMixin, ProviderProtocol, Runnable, UIBase 1abcdefghi
21from ...exceptions import FastAgencyNATSConnectionError, FastAgencyNATSKeyError 1abcdefghi
22from ...logging import get_logger 1abcdefghi
23from ...messages import ( 1abcdefghi
24 AskingMessage,
25 IOMessage,
26 InitiateWorkflowModel,
27 InputResponseModel,
28 MessageProcessorMixin,
29 MultipleChoice,
30 TextInput,
31 TextMessage,
32)
34if TYPE_CHECKING: 1abcdefghi
35 from faststream.nats.subscriber.asyncapi import AsyncAPISubscriber
logger = get_logger(__name__)

# JetStream stream used by both NatsAdapter (server side) and NatsProvider
# (client side); every subscriber in this module is attached to it.
JETSTREAM = JStream(
    subjects=[
        # Requests to start a new workflow. The adapter consumes it through the
        # "initiate_workers" queue group, so any worker may pick a request up.
        "chat.server.initiate_chat",
        # Server -> client traffic, chat.client.messages.<user_uuid>.<workflow_uuid>.
        # Created per workflow; the adapter publishes here and the provider subscribes.
        "chat.client.messages.*.*",
        # Client -> server input responses, chat.server.messages.<user_uuid>.<workflow_uuid>.
        # The adapter worker that accepted the workflow subscribes to it dynamically,
        # which pins the conversation to that worker.
        "chat.server.messages.*.*",
        # Discovery subject.
        "discovery",
    ],
    name="FastAgency",
)
class NatsAdapter(MessageProcessorMixin, CreateWorkflowUIMixin):
    def __init__(
        self,
        provider: ProviderProtocol,
        *,
        nats_url: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        super_conversation: Optional["NatsAdapter"] = None,
    ) -> None:
        """Adapter that runs the workflows of `provider` for clients connected over NATS.

        Args:
            provider (ProviderProtocol): The provider whose workflows are executed.
            nats_url (Optional[str], optional): The NATS URL. Defaults to None in which case 'nats://localhost:4222' is used.
            user (Optional[str], optional): The user. Defaults to None.
            password (Optional[str], optional): The password. Defaults to None.
            super_conversation (Optional["NatsAdapter"], optional): The super conversation. Defaults to None.
        """
        self.provider = provider
        self.nats_url = nats_url or "nats://localhost:4222"
        self.user = user
        self.password = password
        # Raw NatsMessages received from clients (filled by _handle_input,
        # drained by _wait_for_question_response).
        self.queue: Queue = Queue()  # type: ignore[type-arg]

        self.broker = NatsBroker(self.nats_url, user=user, password=password)
        self.app = FastStream(self.broker)
        self.subscriber: "AsyncAPISubscriber"
        # Assigned in initiate_handler once a workflow has been accepted.
        self._input_request_subject: str
        self._input_receive_subject: str

        self.super_conversation: Optional[NatsAdapter] = super_conversation
        self.sub_conversations: list[NatsAdapter] = []

        # Strong references to in-flight background tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced task can be
        # garbage collected before it finishes.
        self._background_tasks: set[asyncio.Task[Any]] = set()

        self._create_initiate_subscriber()

    async def _handle_input(
        self, body: InputResponseModel, msg: NatsMessage, logger: Logger
    ) -> None:
        """Handle input from the client by consuming messages from chat.server.messages.*.*.

        The message is acknowledged and the raw NatsMessage is put on `self.queue`
        so that `_wait_for_question_response` can match it to a pending question.

        Args:
            body (InputResponseModel): The body of the message.
            msg (NatsMessage): The message object.
            logger (Logger): The logger object (gets injected)
        """
        logger.info(
            f"Received message in subject '{self._input_receive_subject}': {body}"
        )
        await msg.ack()
        self.queue.put(msg)

    async def _send_error_msg(self, e: Exception, logger: Logger) -> None:
        """Log an exception and publish it to the client as an error response.

        Args:
            e (Exception): The exception.
            logger (Logger): The logger object (gets injected)
        """
        logger.error(f"Error in chat: {e}")
        logger.error(traceback.format_exc())

        error_msg = InputResponseModel(msg=str(e), error=True, question_uuid=None)
        await self.broker.publish(error_msg, self._input_request_subject)

    def _create_initiate_subscriber(self) -> None:
        """Register the `chat.server.initiate_chat` handler on the broker."""

        @self.broker.subscriber(
            "chat.server.initiate_chat",
            stream=JETSTREAM,
            queue="initiate_workers",
            deliver_policy=api.DeliverPolicy("all"),
        )
        async def initiate_handler(
            body: InitiateWorkflowModel, msg: NatsMessage, logger: Logger
        ) -> None:
            """Initiate the handler.

            1. Subscribes to the chat.server.initiate_chat topic.
            2. When a message is consumed from the topic, it dynamically subscribes to the chat.server.messages.<user_uuid>.<workflow_uuid> topic.
            3. Starts the chat workflow after successfully subscribing to the chat.server.messages.<user_uuid>.<workflow_uuid> topic.

            Args:
                body (InitiateModel): The body of the message.
                msg (NatsMessage): The message object.
                logger (Logger): The logger object (gets injected)

            """
            await msg.ack()

            logger.info(
                f"Message in subject 'chat.server.initiate_chat': {body=} -> from process id {os.getpid()}"
            )
            user_id = body.user_id if body.user_id else "None"
            workflow_uuid = body.workflow_uuid.hex
            self._input_request_subject = (
                f"chat.client.messages.{user_id}.{workflow_uuid}"
            )
            self._input_receive_subject = (
                f"chat.server.messages.{user_id}.{workflow_uuid}"
            )

            # dynamically subscribe to the chat server
            subscriber = self.broker.subscriber(
                subject=self._input_receive_subject,
                stream=JETSTREAM,
                deliver_policy=api.DeliverPolicy("all"),
            )
            subscriber(self._handle_input)
            self.broker.setup_subscriber(subscriber)
            await subscriber.start()

            try:

                async def start_chat(
                    ui_base: UIBase,
                    provider: ProviderProtocol,
                    name: str,
                    params: dict[str, Any],
                    workflow_uuid: str,
                ) -> None:  # type: ignore [return]
                    def _start_chat(
                        ui_base: UIBase,
                        provider: ProviderProtocol,
                        name: str,
                        params: dict[str, Any],
                        workflow_uuid: str,
                    ) -> None:  # type: ignore [return]
                        ui: UI = ui_base.create_workflow_ui(workflow_uuid=workflow_uuid)
                        try:
                            provider.run(
                                name=name,
                                ui=ui,
                                **params,
                            )
                        except Exception as e:
                            # exc_info (not stack_info) attaches the exception's traceback.
                            logger.error(
                                f"Unexpected error in NatsAdapter.start_chat: {e}",
                                exc_info=True,
                            )
                            ui.error(
                                sender="NatsAdapter",
                                short=f"Unexpected error: {e}",
                                long=traceback.format_exc(),
                            )
                            return

                    # provider.run is synchronous, so run it in a worker thread
                    # to keep the event loop free.
                    return await asyncify(_start_chat)(
                        ui_base, provider, name, params, workflow_uuid
                    )

                task = asyncio.create_task(
                    start_chat(
                        self, self.provider, body.name, body.params, workflow_uuid
                    )
                )  # type: ignore
                # Keep a strong reference until the task is done (see __init__).
                self._background_tasks.add(task)

                async def close_subscriber() -> None:
                    try:
                        await subscriber.close()
                    except Exception as e:
                        logger.error(f"Error in callback: {e}")
                        logger.error(traceback.format_exc())

                def on_chat_done(t: asyncio.Task[Any]) -> None:
                    # Runs on the event loop once the workflow finishes: drop the
                    # reference and close the per-workflow subscriber. The cleanup
                    # task is referenced too so it cannot be collected mid-flight.
                    self._background_tasks.discard(t)
                    cleanup = asyncio.create_task(close_subscriber())
                    self._background_tasks.add(cleanup)
                    cleanup.add_done_callback(self._background_tasks.discard)

                task.add_done_callback(on_chat_done)

            except Exception as e:
                await self._send_error_msg(e, logger)

    async def _publish_discovery(self) -> None:
        """Publish the discovery message.

        Stores each workflow name of the provider with its UTF-8 encoded
        description in the "discovery" key-value bucket.
        """
        jetstream_key_value = await self.broker.key_value(bucket="discovery")

        names = self.provider.names
        for name in names:
            description = self.provider.get_description(name)
            await jetstream_key_value.put(name, description.encode())

    # todo: make it a router
    @asynccontextmanager
    async def lifespan(self, app: Any) -> AsyncIterator[None]:
        """Start the broker and publish discovery data for the lifetime of the app."""
        async with self.broker:
            await self.broker.start()
            await self._publish_discovery()
            try:
                yield
            finally:
                await self.broker.close()

    def visit_default(self, message: IOMessage) -> None:
        """Forward a message that needs no reply to the client."""
        content = message.model_dump()
        logger.debug(f"visit_default(): {content=}")
        syncify(self.broker.publish)(content, self._input_request_subject)

    def visit_text_message(self, message: TextMessage) -> None:
        """Forward a text message to the client."""
        content = message.model_dump()
        logger.debug(f"visit_text_message(): {content=}")
        syncify(self.broker.publish)(content, self._input_request_subject)

    async def _wait_for_question_response_with_timeout(
        self, question_id: str, *, timeout: int = 180
    ) -> InputResponseModel:
        """Wait for the question response.

        Args:
            question_id (str): The question ID.
            timeout (int, optional): The timeout in seconds. Defaults to 180.

        Returns:
            InputResponseModel: The client's response, or an error response
                if no reply arrived within `timeout` seconds.
        """
        try:
            # Give up after `timeout` seconds.
            return await asyncio.wait_for(
                self._wait_for_question_response(question_id), timeout=timeout
            )
        except asyncio.TimeoutError:
            logger.debug(
                f"Timeout: User did not send a reply within {timeout} seconds."
            )
            return InputResponseModel(
                msg="User didn't send a reply. Exit the workflow execution.",
                question_uuid=question_id,
                error=True,
            )

    async def _wait_for_question_response(self, question_id: str) -> InputResponseModel:
        """Poll `self.queue` until a response whose question UUID matches arrives.

        Responses that belong to other questions are put back on the queue.

        Args:
            question_id (str): Hex form of the question UUID to wait for.

        Returns:
            InputResponseModel: The matching response.
        """
        while True:
            while self.queue.empty():  # noqa: ASYNC110
                await asyncio.sleep(0.1)

            msg: NatsMessage = self.queue.get()
            self.queue.task_done()
            input_response = InputResponseModel.model_validate_json(
                msg.raw_message.data.decode("utf-8")
            )

            question_id_hex = (
                input_response.question_uuid.hex
                if input_response.question_uuid
                else "None"
            )
            logger.debug(question_id_hex)
            logger.debug(question_id)
            logger.debug(question_id_hex == question_id)
            if question_id_hex == question_id:
                logger.debug("Breaking the while loop")
                logger.debug("Got the response")
                return input_response

            # Not ours: put it back and yield to the event loop. Without this
            # await, a queue holding only foreign responses would be re-read in a
            # tight loop that never yields, blocking the loop and wait_for's timeout.
            self.queue.put(msg)
            await asyncio.sleep(0.1)

    def visit_text_input(self, message: TextInput) -> str:
        """Send a text question to the client and block until it is answered.

        Returns:
            str: The `msg` field of the response (or of the timeout error response).
        """
        content = message.model_dump()
        question_id = message.uuid
        logger.info(f"visit_text_input(): {content=}")
        syncify(self.broker.publish)(content, self._input_request_subject)
        logger.info(
            f"visit_text_input(): published message '{content}' to {self._input_request_subject}"
        )

        input_response: InputResponseModel = syncify(
            self._wait_for_question_response_with_timeout
        )(question_id=question_id)
        logger.info(input_response)
        return input_response.msg

    def visit_multiple_choice(self, message: MultipleChoice) -> str:
        """Send a multiple-choice question to the client and block until it is answered.

        Returns:
            str: The `msg` field of the response (or of the timeout error response).
        """
        content = message.model_dump()
        question_id = message.uuid
        logger.info(f"visit_multiple_choice(): {content=}")
        syncify(self.broker.publish)(content, self._input_request_subject)

        input_response: InputResponseModel = syncify(
            self._wait_for_question_response_with_timeout
        )(question_id=question_id)
        logger.info(input_response)
        return input_response.msg

    def process_message(self, message: IOMessage) -> Optional[str]:
        """Dispatch `message` to its visitor, never letting an exception escape.

        Returns:
            Optional[str]: The visitor's result. On error, a generic error string
                for asking messages and None otherwise.
        """
        try:
            return self.visit(message)
        except Exception as e:
            # exc_info (not stack_info) attaches the exception's traceback.
            logger.error(f"Error in process_message: {e}", exc_info=True)
            # do not reraise, we must go on
            if isinstance(message, AskingMessage):
                return "Error: Something went wrong. Please check logs for details."
            return None

    def create_subconversation(self) -> "NatsAdapter":
        """Return this adapter; sub-conversations share it."""
        return self

    @classmethod
    def create_provider(
        cls,
        nats_url: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
    ) -> ProviderProtocol:
        """Create a client-side NatsProvider for the given server and credentials."""
        return NatsProvider(nats_url=nats_url, user=user, password=password)

    @contextmanager
    def create(self, app: Runnable, import_string: str) -> Iterator[None]:
        """Not supported by this adapter; always raises NotImplementedError."""
        raise NotImplementedError("NatsAdapter.create() is not implemented")

    def start(
        self,
        *,
        app: Runnable,
        import_string: str,
        name: Optional[str] = None,
        params: dict[str, Any],
        single_run: bool = False,
    ) -> None:
        """Not supported by this adapter; always raises NotImplementedError."""
        raise NotImplementedError("NatsAdapter.start() is not implemented")
class NatsProvider(ProviderProtocol):
    def __init__(
        self,
        nats_url: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
    ) -> None:
        """Initialize the nats workflows.

        Args:
            nats_url (Optional[str], optional): The NATS URL. Defaults to None in which case 'nats://localhost:4222' is used.
            user (Optional[str], optional): The user. Defaults to None.
            password (Optional[str], optional): The password. Defaults to None.
        """
        self.nats_url = nats_url or "nats://localhost:4222"
        self.user = user
        self.password = password

        self.broker = NatsBroker(self.nats_url, user=self.user, password=self.password)
        self.app = FastStream(self.broker)

        self._initiate_chat_subject: str = "chat.server.initiate_chat"

        # True while run() owns an open broker lifespan.
        self.is_broker_running: bool = False

    async def _setup_subscriber(
        self, ui: UI, from_server_subject: str, to_server_subject: str
    ) -> None:
        """Subscribe to server messages and relay them to `ui`.

        Messages flagged with "error" are shown as error messages. For asking
        messages, the UI's answer is published back on `to_server_subject`.

        Args:
            ui (UI): The UI that renders messages and collects answers.
            from_server_subject (str): Subject the server publishes on.
            to_server_subject (str): Subject answers are published to.
        """
        logger.info(
            f"Setting up subscriber for {from_server_subject=}, {to_server_subject=}"
        )

        async def consume_msg_from_nats(msg: dict[str, Any], logger: Logger) -> None:
            logger.debug(f"Received message from topic {from_server_subject}: {msg}")
            iomessage = (
                IOMessage.create(**{"type": "error", "long": msg["msg"]})
                if msg.get("error")
                else IOMessage.create(**msg)
            )
            if isinstance(iomessage, AskingMessage):
                processed_message = ui.process_message(iomessage)
                response = InputResponseModel(
                    msg=processed_message, question_uuid=iomessage.uuid
                )
                logger.debug(f"Processed response: {response}")
                await self.broker.publish(response, to_server_subject)
            else:
                ui.process_message(iomessage)

        subscriber = self.broker.subscriber(
            from_server_subject,
            stream=JETSTREAM,
            deliver_policy=api.DeliverPolicy("all"),
        )
        subscriber(consume_msg_from_nats)
        self.broker.setup_subscriber(subscriber)
        await subscriber.start()
        logger.info(f"Subscriber for {from_server_subject} started")

    def run(
        self,
        name: str,
        ui: UI,
        user_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Request workflow `name` from a NATS worker and relay its messages to `ui`.

        Publishes an InitiateWorkflowModel on `chat.server.initiate_chat`, then
        subscribes to the per-workflow server subject. The inner loop below has
        no exit condition of its own — NOTE(review): confirm how the workflow is
        meant to terminate.

        Args:
            name (str): The workflow name.
            ui (UI): The UI used for this workflow.
            user_id (Optional[str], optional): The user ID. Defaults to None.
            **kwargs (Any): Parameters forwarded to the workflow.

        Returns:
            str: A completion message.
        """
        workflow_uuid = ui._workflow_uuid
        init_message = InitiateWorkflowModel(
            user_id=user_id,
            workflow_uuid=workflow_uuid,
            params=kwargs,
            name=name,
        )
        # Build the per-workflow subjects from the validated model exactly the
        # way NatsAdapter's initiate handler does, so both sides agree.
        subject_user_id = init_message.user_id if init_message.user_id else "None"
        subject_workflow_uuid = init_message.workflow_uuid.hex
        _from_server_subject = (
            f"chat.client.messages.{subject_user_id}.{subject_workflow_uuid}"
        )
        _to_server_subject = (
            f"chat.server.messages.{subject_user_id}.{subject_workflow_uuid}"
        )

        async def send_initiate_chat_msg() -> None:
            await self.broker.publish(init_message, self._initiate_chat_subject)
            logger.info("Initiate chat message sent")

        @asynccontextmanager
        async def lifespan() -> AsyncIterator[None]:
            async with self.broker:
                await self.broker.start()
                logger.debug("Broker started")
                try:
                    yield
                finally:
                    await self.broker.close()

        async def _setup_and_run() -> None:
            await send_initiate_chat_msg()
            await self._setup_subscriber(ui, _from_server_subject, _to_server_subject)
            while True:  # noqa: ASYNC110
                await asyncio.sleep(0.1)

        async def run_lifespan() -> None:
            if self.is_broker_running:
                await _setup_and_run()
                return
            self.is_broker_running = True
            try:
                async with lifespan():
                    await _setup_and_run()
            finally:
                # The broker is closed once lifespan() exits; reset the flag so a
                # later run() starts it again instead of using a closed broker.
                self.is_broker_running = False

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        loop.run_until_complete(run_lifespan())

        return "NatsWorkflows.run() completed"

    @asynccontextmanager
    async def _get_jetstream_context(self) -> AsyncIterator[JetStreamContext]:
        """Yield a JetStream context on a fresh NATS connection, closed on exit."""
        nc = NatsClient()
        await nc.connect(self.nats_url, user=self.user, password=self.password)
        js = nc.jetstream()
        try:
            yield js
        finally:
            await nc.close()

    @asynccontextmanager
    async def _get_jetstream_key_value(
        self, bucket: str = "discovery"
    ) -> AsyncIterator[KeyValue]:
        """Yield the key-value store for `bucket` (via `create_key_value`)."""
        async with self._get_jetstream_context() as js:
            kv = await js.create_key_value(bucket=bucket)
            yield kv

    async def _get_names(self) -> list[str]:
        """Return the workflow names in the discovery bucket ([] when it has no keys).

        Raises:
            FastAgencyNATSConnectionError: If the NATS server is unreachable.
        """
        try:
            async with self._get_jetstream_key_value() as kv:
                names = await kv.keys()
        except NoKeysError:
            names = []
        except NoServersError as e:
            raise FastAgencyNATSConnectionError(
                f"Unable to connect to NATS server at {self.nats_url}"
            ) from e

        return names

    async def _get_description(self, name: str) -> str:
        """Return the decoded description stored for `name` ("" if the value is empty).

        Raises:
            FastAgencyNATSKeyError: If `name` is not in the discovery bucket.
            FastAgencyNATSConnectionError: If the NATS server is unreachable.
        """
        try:
            async with self._get_jetstream_key_value() as kv:
                description = await kv.get(name)
                return description.value.decode() if description.value else ""
        except KeyNotFoundError as e:
            raise FastAgencyNATSKeyError(
                f"Workflow name {name} not found to get description"
            ) from e
        except NoServersError as e:
            raise FastAgencyNATSConnectionError(
                f"Unable to connect to NATS server at {self.nats_url}"
            ) from e

    @property
    def names(self) -> list[str]:
        """Workflow names published by NATS workers (runs its own event loop)."""
        names = asyncio.run(self._get_names())
        logger.debug(f"Names: {names}")
        return names

    def get_description(self, name: str) -> str:
        """Return the description of workflow `name` (runs its own event loop)."""
        description = asyncio.run(self._get_description(name))
        logger.debug(f"Description: {description}")
        return description