Coverage for fastagency/adapters/fastapi/base.py: 54%
145 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-19 12:16 +0000
1## IONats part
3import asyncio 1aefghibcd
4import json 1aefghibcd
5from collections.abc import Iterator 1aefghibcd
6from contextlib import AsyncExitStack, contextmanager 1aefghibcd
7from typing import Any, Callable, Optional 1aefghibcd
8from uuid import UUID, uuid4 1aefghibcd
10import requests 1aefghibcd
11import websockets 1aefghibcd
12from asyncer import asyncify, syncify 1aefghibcd
13from fastapi import ( 1aefghibcd
14 APIRouter,
15 Depends,
16 HTTPException,
17 Request,
18 Response,
19 WebSocket,
20)
21from fastapi.dependencies.utils import get_dependant, solve_dependencies 1aefghibcd
22from pydantic import BaseModel 1aefghibcd
24from fastagency.logging import get_logger 1aefghibcd
26from ...base import ( 1aefghibcd
27 UI,
28 CreateWorkflowUIMixin,
29 ProviderProtocol,
30 Runnable,
31 UIBase,
32 WorkflowsProtocol,
33)
34from ...exceptions import ( 1aefghibcd
35 FastAgencyConnectionError,
36 FastAgencyFastAPIConnectionError,
37 FastAgencyKeyError,
38)
39from ...messages import ( 1aefghibcd
40 AskingMessage,
41 IOMessage,
42 InitiateWorkflowModel,
43 InputResponseModel,
44 MessageProcessorMixin,
45)
# Module-level logger shared by FastAPIAdapter and FastAPIProvider below.
logger = get_logger(__name__)
class InititateChatModel(BaseModel):
    """Request body POSTed by ``FastAPIProvider`` to the initiate-workflow route."""

    # Name of the workflow to run on the server.
    workflow_name: str
    # Client-side workflow identifier (sent as a string).
    workflow_uuid: str
    # Required field, but it may be null for anonymous users.
    user_id: Optional[str]
    # Keyword arguments forwarded to the workflow.
    params: dict[str, Any]
class WorkflowInfo(BaseModel):
    """One entry of the discovery route's response: a workflow and its description."""

    # Workflow name as known to the provider.
    name: str
    # Human-readable description returned by ``provider.get_description``.
    description: str
class FastAPIAdapter(MessageProcessorMixin, CreateWorkflowUIMixin):
    """Serve FastAgency workflows from a FastAPI application.

    ``self.router`` is an ``APIRouter`` with three routes:

    * POST ``initiate_workflow_path``: returns an ``InitiateWorkflowModel``
      carrying a freshly generated workflow UUID.
    * websocket ``ws_path``: runs the requested workflow on ``provider`` and
      relays every UI message to the connected client.
    * GET ``discovery_path``: lists the provider's workflows with their
      descriptions.
    """

    def __init__(
        self,
        provider: ProviderProtocol,
        *,
        initiate_workflow_path: str = "/fastagency/initiate_workflow",
        discovery_path: str = "/fastagency/discovery",
        ws_path: str = "/fastagency/ws",
        get_user_id: Optional[Callable[..., Optional[str]]] = None,
    ) -> None:
        """Provider for FastAPI.

        Args:
            provider (ProviderProtocol): The provider that actually runs workflows.
            initiate_workflow_path (str, optional): The initiate workflow path. Defaults to "/fastagency/initiate_workflow".
            discovery_path (str, optional): The discovery path. Defaults to "/fastagency/discovery".
            ws_path (str, optional): The websocket path. Defaults to "/fastagency/ws".
            get_user_id (Optional[Callable[..., Optional[str]]], optional): A FastAPI
                dependency returning the current user id, or None. Defaults to None,
                in which case a dependency that always returns None is used.
        """
        self.provider = provider

        self.initiate_workflow_path = initiate_workflow_path
        self.discovery_path = discovery_path
        self.ws_path = ws_path

        self.get_user_id = get_user_id or (lambda: None)

        # Websockets of currently running workflows, keyed by workflow UUID hex.
        self.websockets: dict[str, WebSocket] = {}

        self.router = self.setup_routes()

    async def get_user_id_websocket(self, websocket: WebSocket) -> Optional[str]:
        """Resolve the ``get_user_id`` dependency for a websocket connection.

        The dependency is solved manually against a ``Request`` built from the
        websocket's scope, so failures surface as ``HTTPException`` before the
        websocket is accepted.

        Args:
            websocket: The incoming (not yet accepted) websocket.

        Returns:
            The user id produced by ``self.get_user_id`` (possibly None).

        Raises:
            HTTPException: Raised by the dependency itself (e.g. an auth
                failure), or with status 422 when the dependency's inputs fail
                validation.
        """

        def get_user_id_depends_stub(
            user_id: Optional[str] = Depends(self.get_user_id),
        ) -> Optional[str]:
            raise RuntimeError(
                "Stub get_user_id_depends_stub called"
            )  # pragma: no cover

        dependant = get_dependant(path="", call=get_user_id_depends_stub)

        # Bound before the try so the ``finally`` below can always restore it.
        scope = websocket.scope
        try:
            async with AsyncExitStack() as cm:
                # Request() needs an HTTP scope, so relabel the websocket scope
                # temporarily; it is restored in ``finally``.
                scope["type"] = "http"

                solved_dependency = await solve_dependencies(
                    dependant=dependant,
                    request=Request(scope=scope),  # Inject the request here
                    body=None,
                    dependency_overrides_provider=None,
                    async_exit_stack=cm,
                    embed_body_fields=False,
                )
        finally:
            scope["type"] = "websocket"

        # On validation errors "user_id" is absent from ``values``. Report it
        # the way the websocket route expects instead of raising a bare KeyError.
        if solved_dependency.errors:
            raise HTTPException(status_code=422, detail=solved_dependency.errors)

        return solved_dependency.values["user_id"]  # type: ignore[no-any-return]

    def setup_routes(self) -> APIRouter:
        """Build the router holding the initiate, websocket and discovery routes."""
        router = APIRouter()

        @router.post(self.initiate_workflow_path)
        async def initiate_chat(
            initiate_chat: InititateChatModel,
            user_id: Optional[str] = Depends(self.get_user_id),
        ) -> InitiateWorkflowModel:
            # NOTE(review): the client-supplied ``initiate_chat.workflow_uuid``
            # is not used; a new UUID is always generated here.
            workflow_uuid: UUID = uuid4()

            init_msg = InitiateWorkflowModel(
                user_id=user_id,
                workflow_uuid=workflow_uuid,
                params=initiate_chat.params,
                name=initiate_chat.workflow_name,
            )

            return init_msg

        @router.websocket(self.ws_path)
        async def websocket_endpoint(
            websocket: WebSocket,
        ) -> None:
            # Authenticate before accepting so a failure can be sent back as
            # an HTTP denial response with the exception's status and headers.
            try:
                user_id = await self.get_user_id_websocket(websocket)
            except HTTPException as e:
                headers = getattr(e, "headers", None)
                await websocket.send_denial_response(
                    Response(status_code=e.status_code, headers=headers)
                )
                return

            logger.info("Websocket connected")
            await websocket.accept()
            logger.info("Websocket accepted")

            init_msg_json = await websocket.receive_text()
            logger.info(f"Received message: {init_msg_json}")

            init_msg = InitiateWorkflowModel.model_validate_json(init_msg_json)

            workflow_uuid = init_msg.workflow_uuid.hex
            # visit_default looks up the websocket under this key.
            self.websockets[workflow_uuid] = websocket

            try:
                # provider.run is synchronous; asyncify runs it off the event loop.
                await asyncify(self.provider.run)(
                    name=init_msg.name,
                    ui=self.create_workflow_ui(workflow_uuid),
                    user_id=user_id if user_id else "None",
                    **init_msg.params,
                )
            except Exception as e:
                # exc_info (not stack_info) so the exception's traceback is logged.
                logger.error(f"Error in websocket_endpoint: {e}", exc_info=True)
            finally:
                self.websockets.pop(workflow_uuid)

        @router.get(
            self.discovery_path,
            responses={
                404: {"detail": "Key Not Found"},
                504: {"detail": "Unable to connect to provider"},
            },
        )
        def discovery(
            user_id: Optional[str] = Depends(self.get_user_id),
        ) -> list[WorkflowInfo]:
            # ``user_id`` is unused; declaring it makes the route run the
            # get_user_id dependency (e.g. to reject unauthenticated callers).
            try:
                names = self.provider.names
            except FastAgencyConnectionError as e:
                raise HTTPException(status_code=504, detail=str(e)) from e

            try:
                descriptions = [self.provider.get_description(name) for name in names]
            except FastAgencyKeyError as e:
                raise HTTPException(status_code=404, detail=str(e)) from e

            return [
                WorkflowInfo(name=name, description=description)
                for name, description in zip(names, descriptions)
            ]

        return router

    def visit_default(self, message: IOMessage) -> Optional[str]:
        """Send ``message`` to its workflow's websocket.

        The coroutine is driven via ``syncify`` so this can be called from the
        synchronous workflow code.

        Returns:
            The reply text read from the websocket for an ``AskingMessage``;
            None for any other message.

        Raises:
            RuntimeError: If no websocket is registered for the message's workflow.
        """

        async def a_visit_default(
            self: FastAPIAdapter, message: IOMessage
        ) -> Optional[str]:
            workflow_uuid = message.workflow_uuid
            if workflow_uuid not in self.websockets:
                error_msg = (
                    f"Workflow {workflow_uuid} not found in websockets: {self.websockets}"
                )
                logger.error(error_msg)
                raise RuntimeError(error_msg)
            websocket = self.websockets[workflow_uuid]  # type: ignore[index]
            await websocket.send_text(json.dumps(message.model_dump()))

            if isinstance(message, AskingMessage):
                response = await websocket.receive_text()
                return response
            return None

        return syncify(a_visit_default)(self, message)

    def create_subconversation(self) -> UIBase:
        """Reuse this adapter for subconversations."""
        return self

    @contextmanager
    def create(self, app: Runnable, import_string: str) -> Iterator[None]:
        """Not supported by this adapter."""
        raise NotImplementedError("create")

    def start(
        self,
        *,
        app: "Runnable",
        import_string: str,
        name: Optional[str] = None,
        params: dict[str, Any],
        single_run: bool = False,
    ) -> None:
        """Not supported by this adapter."""
        raise NotImplementedError("start")

    @classmethod
    def create_provider(
        cls,
        fastapi_url: str,
    ) -> ProviderProtocol:
        """Return a ``FastAPIProvider`` pointed at ``fastapi_url``."""
        return FastAPIProvider(
            fastapi_url=fastapi_url,
        )
class FastAPIProvider(ProviderProtocol):
    """Client-side provider for workflows served by a ``FastAPIAdapter``.

    Discovery uses HTTP GET. Running a workflow POSTs an initiate request,
    then opens a websocket and feeds every received message through the
    local ``ui``.
    """

    def __init__(
        self,
        fastapi_url: str,
        initiate_workflow_path: str = "/fastagency/initiate_workflow",
        discovery_path: str = "/fastagency/discovery",
        ws_path: str = "/fastagency/ws",
    ) -> None:
        """Initialize the fastapi workflows.

        Args:
            fastapi_url: Base URL of the FastAPI server; one trailing "/" is stripped.
            initiate_workflow_path: Path of the initiate-workflow POST route.
            discovery_path: Path of the discovery GET route (leading "/" expected).
            ws_path: Path of the websocket route.
        """
        self._workflows: dict[
            str, tuple[Callable[[WorkflowsProtocol, UIBase, str, str], str], str]
        ] = {}

        self.fastapi_url = (
            fastapi_url[:-1] if fastapi_url.endswith("/") else fastapi_url
        )
        # Swap the leading "http" for "ws": "http://h" -> "ws://h" and
        # "https://h" -> "wss://h". Assumes fastapi_url starts with "http".
        self.ws_url = "ws" + self.fastapi_url[4:]

        self.is_broker_running: bool = False

        self.initiate_workflow_path = initiate_workflow_path
        self.discovery_path = discovery_path
        self.ws_path = ws_path

    def _send_initiate_chat_msg(
        self,
        workflow_name: str,
        workflow_uuid: str,
        user_id: Optional[str],
        params: dict[str, Any],
    ) -> InitiateWorkflowModel:
        """POST an initiate-workflow request and parse the server's reply."""
        msg = InititateChatModel(
            workflow_name=workflow_name,
            workflow_uuid=workflow_uuid,
            user_id=user_id,
            params=params,
        )

        payload = msg.model_dump()

        resp = requests.post(
            f"{self.fastapi_url}{self.initiate_workflow_path}", json=payload, timeout=5
        )
        # Parse the body once and reuse it for logging and validation.
        resp_json = resp.json()
        logger.info(f"Initiate chat response: {resp_json}")
        return InitiateWorkflowModel(**resp_json)

    async def _publish_websocket_message(
        self,
        websocket: websockets.WebSocketClientProtocol,
        message: InputResponseModel,
    ) -> None:
        """Send ``message`` as JSON over ``websocket``."""
        payload = message.model_dump_json()
        await websocket.send(payload)
        logger.info(f"Message sent to websocket ({websocket}): {message}")

    async def _run_websocket_subscriber(
        self,
        ui: UI,
        workflow_name: str,
        user_id: Optional[str],
        from_server_subject: str,
        to_server_subject: str,
        params: dict[str, Any],
    ) -> None:
        """Run one workflow over the websocket until the server closes it.

        ``from_server_subject``/``to_server_subject`` are accepted for interface
        compatibility but are not used here.
        """
        connect_url = f"{self.ws_url}{self.ws_path}"
        async with websockets.connect(connect_url) as websocket:
            init_workflow_msg = InitiateWorkflowModel(
                name=workflow_name,
                workflow_uuid=ui._workflow_uuid,
                user_id=user_id,
                params=params,
            )
            await websocket.send(init_workflow_msg.model_dump_json())

            # Async iteration stops when the server closes the connection
            # normally (i.e. the workflow finished); abnormal closures still
            # raise ConnectionClosedError. A bare ``recv()`` loop would raise
            # ConnectionClosedOK even on a normal close.
            async for response in websocket:
                response = (
                    response.decode() if isinstance(response, bytes) else response
                )

                logger.info(f"Received message: {response}")

                msg = IOMessage.create(**json.loads(response))

                retval = await asyncify(ui.process_message)(msg)
                logger.info(f"Message {msg}: processed with response {retval}")

                if isinstance(msg, AskingMessage):
                    if retval is None:
                        logger.warning(
                            f"Message {msg}: response is None. Skipping response to websocket"
                        )
                    else:
                        await websocket.send(retval)
                        logger.info(
                            f"Message {msg}: response {retval} sent to websocket"
                        )

    def run(
        self,
        name: str,
        ui: UI,
        user_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Run workflow ``name`` remotely, blocking until it completes.

        Args:
            name: Workflow name.
            ui: Local UI that processes messages coming from the server.
            user_id: Optional user id sent with the initiate request.
            **kwargs: Workflow parameters.

        Returns:
            A fixed completion message.
        """
        workflow_uuid = ui._workflow_uuid

        initiate_workflow = self._send_initiate_chat_msg(
            name, workflow_uuid=workflow_uuid, user_id=user_id, params=kwargs
        )
        user_id = initiate_workflow.user_id if initiate_workflow.user_id else "None"
        workflow_uuid = initiate_workflow.workflow_uuid.hex

        _from_server_subject = f"chat.client.messages.{user_id}.{workflow_uuid}"
        _to_server_subject = f"chat.server.messages.{user_id}.{workflow_uuid}"

        async def run_lifespan() -> None:
            # Both former branches awaited the same coroutine; only the flag
            # update differed, and setting it to True again is a no-op.
            self.is_broker_running = True
            await self._run_websocket_subscriber(
                ui,
                name,
                user_id,
                _from_server_subject,
                _to_server_subject,
                kwargs,
            )

        # Reuse the thread's event loop if there is one, otherwise create one
        # (e.g. when called from a worker thread).
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        loop.run_until_complete(run_lifespan())

        return "FastAPIWorkflows.run() completed"

    def _get_workflow_info(self) -> list[dict[str, str]]:
        """Fetch the discovery list from the server.

        Raises:
            FastAgencyFastAPIConnectionError: If the server is unreachable.
            FastAgencyConnectionError: If the server answers 504.
            FastAgencyKeyError: If the server answers 404.
        """
        try:
            # discovery_path already starts with "/"; joining with an extra "/"
            # would request "//fastagency/discovery", which the router does not match.
            resp = requests.get(f"{self.fastapi_url}{self.discovery_path}", timeout=15)
        except requests.exceptions.ConnectionError as e:
            raise FastAgencyFastAPIConnectionError(
                f"Unable to connect to FastAPI server at {self.fastapi_url}"
            ) from e
        if resp.status_code == 504:
            raise FastAgencyConnectionError(resp.json()["detail"])
        elif resp.status_code == 404:
            raise FastAgencyKeyError(resp.json()["detail"])
        return resp.json()  # type: ignore [no-any-return]

    def _get_names(self) -> list[str]:
        """Return the names of all workflows the server exposes."""
        return [workflow["name"] for workflow in self._get_workflow_info()]

    def _get_description(self, name: str) -> str:
        """Return the description of workflow ``name``.

        Raises:
            FastAgencyKeyError: If the server does not list ``name``.
        """
        for workflow in self._get_workflow_info():
            if workflow["name"] == name:
                return workflow["description"]
        # Previously StopIteration leaked out of next(); FastAgencyKeyError is
        # what FastAPIAdapter's discovery route maps to a 404.
        raise FastAgencyKeyError(f"Workflow {name!r} not found")

    @property
    def names(self) -> list[str]:
        """Names of all workflows the server exposes."""
        return self._get_names()

    def get_description(self, name: str) -> str:
        """Description of workflow ``name``."""
        return self._get_description(name)