Coverage for fastagency/adapters/fastapi/base.py: 54%

145 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-19 12:16 +0000

## FastAPI adapter part

2 

3import asyncio 1aefghibcd

4import json 1aefghibcd

5from collections.abc import Iterator 1aefghibcd

6from contextlib import AsyncExitStack, contextmanager 1aefghibcd

7from typing import Any, Callable, Optional 1aefghibcd

8from uuid import UUID, uuid4 1aefghibcd

9 

10import requests 1aefghibcd

11import websockets 1aefghibcd

12from asyncer import asyncify, syncify 1aefghibcd

13from fastapi import ( 1aefghibcd

14 APIRouter, 

15 Depends, 

16 HTTPException, 

17 Request, 

18 Response, 

19 WebSocket, 

20) 

21from fastapi.dependencies.utils import get_dependant, solve_dependencies 1aefghibcd

22from pydantic import BaseModel 1aefghibcd

23 

24from fastagency.logging import get_logger 1aefghibcd

25 

26from ...base import ( 1aefghibcd

27 UI, 

28 CreateWorkflowUIMixin, 

29 ProviderProtocol, 

30 Runnable, 

31 UIBase, 

32 WorkflowsProtocol, 

33) 

34from ...exceptions import ( 1aefghibcd

35 FastAgencyConnectionError, 

36 FastAgencyFastAPIConnectionError, 

37 FastAgencyKeyError, 

38) 

39from ...messages import ( 1aefghibcd

40 AskingMessage, 

41 IOMessage, 

42 InitiateWorkflowModel, 

43 InputResponseModel, 

44 MessageProcessorMixin, 

45) 

46 

# Module-level logger obtained from fastagency's logging helper.
logger = get_logger(__name__)

48 

49 

class InititateChatModel(BaseModel):
    # Request body of the initiate-workflow POST route (sent by FastAPIProvider).
    # The misspelled class name ("Inititate") is kept because code refers to it
    # by this exact name.
    workflow_name: str  # name of the workflow to start
    workflow_uuid: str  # client-side UUID; the server route generates its own instead
    user_id: Optional[str]  # no default: the field must be present, but may be None
    params: dict[str, Any]  # passed through as the workflow parameters

55 

56 

class WorkflowInfo(BaseModel):
    # One entry of the discovery route's response: a workflow name and its
    # description as reported by the provider.
    name: str
    description: str

60 

61 

class FastAPIAdapter(MessageProcessorMixin, CreateWorkflowUIMixin):
    """Server-side adapter exposing a workflow provider through FastAPI routes.

    The routes are collected on ``self.router``:

    * a POST route (``initiate_workflow_path``) returning an ``InitiateWorkflowModel``,
    * a websocket route (``ws_path``) that runs a workflow and relays its messages,
    * a GET route (``discovery_path``) listing workflow names and descriptions.
    """

    def __init__(
        self,
        provider: ProviderProtocol,
        *,
        initiate_workflow_path: str = "/fastagency/initiate_workflow",
        discovery_path: str = "/fastagency/discovery",
        ws_path: str = "/fastagency/ws",
        get_user_id: Optional[Callable[..., Optional[str]]] = None,
    ) -> None:
        """Adapter exposing a provider's workflows over FastAPI.

        Args:
            provider (ProviderProtocol): The provider whose workflows are served.
            initiate_workflow_path (str, optional): The initiate workflow path. Defaults to "/fastagency/initiate_workflow".
            discovery_path (str, optional): The discovery path. Defaults to "/fastagency/discovery".
            ws_path (str, optional): The websocket path. Defaults to "/fastagency/ws".
            get_user_id (Optional[Callable[..., Optional[str]]], optional): A FastAPI
                dependency returning the current user id or None. It is injected with
                ``Depends`` into every route and may raise ``HTTPException`` to reject
                a request. Defaults to a callable that always returns None.
        """
        self.provider = provider

        self.initiate_workflow_path = initiate_workflow_path
        self.discovery_path = discovery_path
        self.ws_path = ws_path

        self.get_user_id = get_user_id or (lambda: None)

        # Open websockets keyed by the hex string of the workflow UUID;
        # visit_default looks up a message's websocket here.
        self.websockets: dict[str, WebSocket] = {}

        self.router = self.setup_routes()

    async def get_user_id_websocket(self, websocket: WebSocket) -> Optional[str]:
        """Resolve ``self.get_user_id`` for a websocket connection.

        ``get_user_id`` is written as an HTTP dependency, so it is solved with
        FastAPI's dependency machinery against a ``Request`` built from the
        websocket scope. The scope ``type`` is temporarily switched to ``"http"``.

        Returns:
            The user id produced by the dependency (may be None).

        Raises:
            HTTPException: propagated unchanged from the dependency.
        """

        def get_user_id_depends_stub(
            user_id: Optional[str] = Depends(self.get_user_id),
        ) -> Optional[str]:
            raise RuntimeError(
                "Stub get_user_id_depends_stub called"
            )  # pragma: no cover

        dependant = get_dependant(path="", call=get_user_id_depends_stub)

        # Bind the scope before entering ``try`` so that ``finally`` can always
        # restore it; this dict is the live websocket's scope, not a copy.
        scope = websocket.scope
        scope["type"] = "http"
        try:
            async with AsyncExitStack() as cm:
                solved_dependency = await solve_dependencies(
                    dependant=dependant,
                    request=Request(scope=scope),  # Inject the request here
                    body=None,
                    dependency_overrides_provider=None,
                    async_exit_stack=cm,
                    embed_body_fields=False,
                )
        finally:
            # Restore the original type so the websocket remains usable.
            scope["type"] = "websocket"

        return solved_dependency.values["user_id"]  # type: ignore[no-any-return]

    def setup_routes(self) -> APIRouter:
        """Build the router with the initiate-workflow, websocket and discovery routes."""
        router = APIRouter()

        @router.post(self.initiate_workflow_path)
        async def initiate_chat(
            initiate_chat: InititateChatModel,
            user_id: Optional[str] = Depends(self.get_user_id),
        ) -> InitiateWorkflowModel:
            # A fresh UUID is generated here; initiate_chat.workflow_uuid is not used.
            workflow_uuid: UUID = uuid4()

            init_msg = InitiateWorkflowModel(
                user_id=user_id,
                workflow_uuid=workflow_uuid,
                params=initiate_chat.params,
                name=initiate_chat.workflow_name,
            )

            return init_msg

        @router.websocket(self.ws_path)
        async def websocket_endpoint(
            websocket: WebSocket,
        ) -> None:
            try:
                user_id = await self.get_user_id_websocket(websocket)
            except HTTPException as e:
                # Reject the handshake using the dependency's status and headers.
                headers = getattr(e, "headers", None)
                await websocket.send_denial_response(
                    Response(status_code=e.status_code, headers=headers)
                )
                return

            logger.info("Websocket connected")
            await websocket.accept()
            logger.info("Websocket accepted")

            # The first text frame is an InitiateWorkflowModel describing the run.
            init_msg_json = await websocket.receive_text()
            logger.info(f"Received message: {init_msg_json}")

            init_msg = InitiateWorkflowModel.model_validate_json(init_msg_json)

            workflow_uuid = init_msg.workflow_uuid.hex
            self.websockets[workflow_uuid] = websocket

            try:
                # provider.run is a sync call; asyncify runs it in a worker thread
                # so this coroutine does not block while the workflow executes.
                await asyncify(self.provider.run)(
                    name=init_msg.name,
                    ui=self.create_workflow_ui(workflow_uuid),
                    user_id=user_id if user_id else "None",
                    **init_msg.params,
                )
            except Exception as e:
                # exc_info (not stack_info) records the exception's own traceback.
                logger.error(f"Error in websocket_endpoint: {e}", exc_info=True)
            finally:
                self.websockets.pop(workflow_uuid)

        @router.get(
            self.discovery_path,
            responses={
                404: {"detail": "Key Not Found"},
                504: {"detail": "Unable to connect to provider"},
            },
        )
        def discovery(
            user_id: Optional[str] = Depends(self.get_user_id),
        ) -> list[WorkflowInfo]:
            # user_id is unused, but resolving the dependency lets get_user_id
            # reject unauthorized requests.
            try:
                names = self.provider.names
            except FastAgencyConnectionError as e:
                raise HTTPException(status_code=504, detail=str(e)) from e

            try:
                descriptions = [self.provider.get_description(name) for name in names]
            except FastAgencyKeyError as e:
                raise HTTPException(status_code=404, detail=str(e)) from e

            return [
                WorkflowInfo(name=name, description=description)
                for name, description in zip(names, descriptions)
            ]

        return router

    def visit_default(self, message: IOMessage) -> Optional[str]:
        """Forward ``message`` as JSON to the websocket of its workflow.

        For an ``AskingMessage``, waits for the client's text reply and returns it.
        For any other message, returns None.

        Raises:
            RuntimeError: if no websocket is registered for ``message.workflow_uuid``.
        """

        async def a_visit_default(
            self: FastAPIAdapter, message: IOMessage
        ) -> Optional[str]:
            workflow_uuid = message.workflow_uuid
            if workflow_uuid not in self.websockets:
                logger.error(
                    f"Workflow {workflow_uuid} not found in websockets: {self.websockets}"
                )
                raise RuntimeError(
                    f"Workflow {workflow_uuid} not found in websockets: {self.websockets}"
                )
            websocket = self.websockets[workflow_uuid]  # type: ignore[index]
            await websocket.send_text(json.dumps(message.model_dump()))

            if isinstance(message, AskingMessage):
                response = await websocket.receive_text()
                return response
            return None

        # syncify runs the coroutine from synchronous code; presumably this is the
        # worker thread started by asyncify in websocket_endpoint (verify).
        return syncify(a_visit_default)(self, message)

    def create_subconversation(self) -> UIBase:
        """Return this adapter itself; sub-conversations share its websockets."""
        return self

    @contextmanager
    def create(self, app: Runnable, import_string: str) -> Iterator[None]:
        """Not supported by this adapter."""
        raise NotImplementedError("create")

    def start(
        self,
        *,
        app: "Runnable",
        import_string: str,
        name: Optional[str] = None,
        params: dict[str, Any],
        single_run: bool = False,
    ) -> None:
        """Not supported by this adapter."""
        raise NotImplementedError("start")

    @classmethod
    def create_provider(
        cls,
        fastapi_url: str,
    ) -> ProviderProtocol:
        """Create a client-side ``FastAPIProvider`` for the server at ``fastapi_url``."""
        return FastAPIProvider(
            fastapi_url=fastapi_url,
        )

254 

255 

class FastAPIProvider(ProviderProtocol):
    """Client-side provider for workflows served by a remote ``FastAPIAdapter``.

    A run starts with an HTTP POST to the initiate-workflow route and is then
    driven over a websocket. Names and descriptions come from the discovery route.
    """

    def __init__(
        self,
        fastapi_url: str,
        initiate_workflow_path: str = "/fastagency/initiate_workflow",
        discovery_path: str = "/fastagency/discovery",
        ws_path: str = "/fastagency/ws",
    ) -> None:
        """Initialize the fastapi workflows.

        Args:
            fastapi_url: Base URL of the FastAPI server; one trailing "/" is removed.
                Expected to start with "http" (it is rewritten to "ws" for websockets).
            initiate_workflow_path: Path of the initiate-workflow POST route.
            discovery_path: Path of the discovery GET route.
            ws_path: Path of the websocket route.
        """
        self._workflows: dict[
            str, tuple[Callable[[WorkflowsProtocol, UIBase, str, str], str], str]
        ] = {}

        self.fastapi_url = (
            fastapi_url[:-1] if fastapi_url.endswith("/") else fastapi_url
        )
        # Replace the leading "http" with "ws": "http://" -> "ws://",
        # "https://" -> "wss://".
        self.ws_url = "ws" + self.fastapi_url[4:]

        self.is_broker_running: bool = False

        self.initiate_workflow_path = initiate_workflow_path
        self.discovery_path = discovery_path
        self.ws_path = ws_path

    @staticmethod
    def _url(base: str, path: str) -> str:
        """Join ``base`` and ``path`` with exactly one "/" between them."""
        return f"{base.rstrip('/')}/{path.lstrip('/')}"

    def _send_initiate_chat_msg(
        self,
        workflow_name: str,
        workflow_uuid: str,
        user_id: Optional[str],
        params: dict[str, Any],
    ) -> InitiateWorkflowModel:
        """POST the initiate-workflow request and return the server's reply.

        Raises:
            FastAgencyFastAPIConnectionError: if the server cannot be reached.
            requests.HTTPError: if the server answers with an error status.
        """
        msg = InititateChatModel(
            workflow_name=workflow_name,
            workflow_uuid=workflow_uuid,
            user_id=user_id,
            params=params,
        )

        payload = msg.model_dump()

        try:
            resp = requests.post(
                self._url(self.fastapi_url, self.initiate_workflow_path),
                json=payload,
                timeout=5,
            )
        except requests.exceptions.ConnectionError as e:
            raise FastAgencyFastAPIConnectionError(
                f"Unable to connect to FastAPI server at {self.fastapi_url}"
            ) from e
        # Fail clearly on error statuses (e.g. a get_user_id dependency rejecting
        # the request) instead of parsing an error body as the model.
        resp.raise_for_status()
        resp_json = resp.json()
        logger.info(f"Initiate chat response: {resp_json}")
        return InitiateWorkflowModel(**resp_json)

    async def _publish_websocket_message(
        self,
        websocket: websockets.WebSocketClientProtocol,
        message: InputResponseModel,
    ) -> None:
        """Send ``message`` as JSON over ``websocket``."""
        payload = message.model_dump_json()
        await websocket.send(payload)
        logger.info(f"Message sent to websocket ({websocket}): {message}")

    async def _run_websocket_subscriber(
        self,
        ui: UI,
        workflow_name: str,
        user_id: Optional[str],
        from_server_subject: str,
        to_server_subject: str,
        params: dict[str, Any],
    ) -> None:
        """Run a workflow over the websocket, relaying server messages to ``ui``.

        Sends an ``InitiateWorkflowModel`` first. Each received message is then
        processed by ``ui``, and replies to ``AskingMessage``s are sent back. Returns
        when the server closes the connection normally. ``from_server_subject`` and
        ``to_server_subject`` are accepted but not used here.
        """
        connect_url = self._url(self.ws_url, self.ws_path)
        async with websockets.connect(connect_url) as websocket:
            init_workflow_msg = InitiateWorkflowModel(
                name=workflow_name,
                workflow_uuid=ui._workflow_uuid,
                user_id=user_id,
                params=params,
            )
            await websocket.send(init_workflow_msg.model_dump_json())

            while True:
                try:
                    response = await websocket.recv()
                except websockets.ConnectionClosedOK:
                    # Normal close from the server, presumably after the workflow
                    # finished; abnormal closes still propagate.
                    logger.info("Websocket closed by server")
                    break
                response = (
                    response.decode() if isinstance(response, bytes) else response
                )

                logger.info(f"Received message: {response}")

                msg = IOMessage.create(**json.loads(response))

                retval = await asyncify(ui.process_message)(msg)
                logger.info(f"Message {msg}: processed with response {retval}")

                if isinstance(msg, AskingMessage):
                    if retval is None:
                        logger.warning(
                            f"Message {msg}: response is None. Skipping response to websocket"
                        )
                    else:
                        await websocket.send(retval)
                        logger.info(
                            f"Message {msg}: response {retval} sent to websocket"
                        )

    def run(
        self,
        name: str,
        ui: UI,
        user_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Run workflow ``name`` on the remote server, blocking until it ends.

        ``kwargs`` are forwarded as workflow parameters. Returns a fixed
        completion message.
        """
        workflow_uuid = ui._workflow_uuid

        initiate_workflow = self._send_initiate_chat_msg(
            name, workflow_uuid=workflow_uuid, user_id=user_id, params=kwargs
        )
        user_id = initiate_workflow.user_id if initiate_workflow.user_id else "None"
        workflow_uuid = initiate_workflow.workflow_uuid.hex

        _from_server_subject = f"chat.client.messages.{user_id}.{workflow_uuid}"
        _to_server_subject = f"chat.server.messages.{user_id}.{workflow_uuid}"

        async def _setup_and_run() -> None:
            await self._run_websocket_subscriber(
                ui,
                name,
                user_id,
                _from_server_subject,
                _to_server_subject,
                kwargs,
            )

        async def run_lifespan() -> None:
            # Mark that a run was started; the subscriber runs either way.
            self.is_broker_running = True
            await _setup_and_run()

        # Reuse the thread's event loop, or create and install one when the
        # thread has none (get_event_loop raises RuntimeError in that case).
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        loop.run_until_complete(run_lifespan())

        return "FastAPIWorkflows.run() completed"

    def _get_workflow_info(self) -> list[dict[str, str]]:
        """Fetch the discovery route's list of ``{"name", "description"}`` dicts.

        Raises:
            FastAgencyFastAPIConnectionError: if the server cannot be reached.
            FastAgencyConnectionError: on a 504 response.
            FastAgencyKeyError: on a 404 response.
            requests.HTTPError: on any other error status.
        """
        try:
            resp = requests.get(
                self._url(self.fastapi_url, self.discovery_path), timeout=15
            )
        except requests.exceptions.ConnectionError as e:
            raise FastAgencyFastAPIConnectionError(
                f"Unable to connect to FastAPI server at {self.fastapi_url}"
            ) from e
        if resp.status_code == 504:
            raise FastAgencyConnectionError(resp.json()["detail"])
        elif resp.status_code == 404:
            raise FastAgencyKeyError(resp.json()["detail"])
        resp.raise_for_status()
        return resp.json()  # type: ignore [no-any-return]

    def _get_names(self) -> list[str]:
        """Return the names of all workflows reported by the server."""
        return [workflow["name"] for workflow in self._get_workflow_info()]

    def _get_description(self, name: str) -> str:
        """Return the description of workflow ``name``.

        Raises:
            FastAgencyKeyError: if the server reports no workflow called ``name``.
        """
        for workflow in self._get_workflow_info():
            if workflow["name"] == name:
                return workflow["description"]
        raise FastAgencyKeyError(f"Workflow '{name}' not found")

    @property
    def names(self) -> list[str]:
        """Names of the workflows available on the server (fetched on each access)."""
        return self._get_names()

    def get_description(self, name: str) -> str:
        """Return the description of workflow ``name`` as reported by the server."""
        return self._get_description(name)