Coverage for pydantic_ai_slim/pydantic_ai/models/bedrock.py: 95.89%

214 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations 

2 

3import functools 

4import typing 

5from collections.abc import AsyncIterator, Iterable 

6from contextlib import asynccontextmanager 

7from dataclasses import dataclass, field 

8from datetime import datetime 

9from typing import TYPE_CHECKING, Generic, Literal, Union, cast, overload 

10 

11import anyio 

12import anyio.to_thread 

13from typing_extensions import ParamSpec, assert_never 

14 

15from pydantic_ai import _utils, result 

16from pydantic_ai.messages import ( 

17 AudioUrl, 

18 BinaryContent, 

19 DocumentUrl, 

20 ImageUrl, 

21 ModelMessage, 

22 ModelRequest, 

23 ModelResponse, 

24 ModelResponsePart, 

25 ModelResponseStreamEvent, 

26 RetryPromptPart, 

27 SystemPromptPart, 

28 TextPart, 

29 ToolCallPart, 

30 ToolReturnPart, 

31 UserPromptPart, 

32) 

33from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client 

34from pydantic_ai.providers import Provider, infer_provider 

35from pydantic_ai.settings import ModelSettings 

36from pydantic_ai.tools import ToolDefinition 

37 

38if TYPE_CHECKING: 

39 from botocore.client import BaseClient 

40 from botocore.eventstream import EventStream 

41 from mypy_boto3_bedrock_runtime import BedrockRuntimeClient 

42 from mypy_boto3_bedrock_runtime.type_defs import ( 

43 ContentBlockOutputTypeDef, 

44 ContentBlockUnionTypeDef, 

45 ConverseResponseTypeDef, 

46 ConverseStreamMetadataEventTypeDef, 

47 ConverseStreamOutputTypeDef, 

48 ImageBlockTypeDef, 

49 InferenceConfigurationTypeDef, 

50 MessageUnionTypeDef, 

51 ToolChoiceTypeDef, 

52 ToolTypeDef, 

53 ) 

54 

55 

LatestBedrockModelNames = Literal[
    'amazon.titan-tg1-large',
    'amazon.titan-text-lite-v1',
    'amazon.titan-text-express-v1',
    'us.amazon.nova-pro-v1:0',
    'us.amazon.nova-lite-v1:0',
    'us.amazon.nova-micro-v1:0',
    'anthropic.claude-3-5-sonnet-20241022-v2:0',
    'us.anthropic.claude-3-5-sonnet-20241022-v2:0',
    'anthropic.claude-3-5-haiku-20241022-v1:0',
    'us.anthropic.claude-3-5-haiku-20241022-v1:0',
    'anthropic.claude-instant-v1',
    'anthropic.claude-v2:1',
    'anthropic.claude-v2',
    'anthropic.claude-3-sonnet-20240229-v1:0',
    'us.anthropic.claude-3-sonnet-20240229-v1:0',
    'anthropic.claude-3-haiku-20240307-v1:0',
    'us.anthropic.claude-3-haiku-20240307-v1:0',
    'anthropic.claude-3-opus-20240229-v1:0',
    'us.anthropic.claude-3-opus-20240229-v1:0',
    'anthropic.claude-3-5-sonnet-20240620-v1:0',
    'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
    'anthropic.claude-3-7-sonnet-20250219-v1:0',
    'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
    'cohere.command-text-v14',
    'cohere.command-r-v1:0',
    'cohere.command-r-plus-v1:0',
    'cohere.command-light-text-v14',
    'meta.llama3-8b-instruct-v1:0',
    'meta.llama3-70b-instruct-v1:0',
    'meta.llama3-1-8b-instruct-v1:0',
    'us.meta.llama3-1-8b-instruct-v1:0',
    'meta.llama3-1-70b-instruct-v1:0',
    'us.meta.llama3-1-70b-instruct-v1:0',
    'meta.llama3-1-405b-instruct-v1:0',
    'us.meta.llama3-2-11b-instruct-v1:0',
    'us.meta.llama3-2-90b-instruct-v1:0',
    'us.meta.llama3-2-1b-instruct-v1:0',
    'us.meta.llama3-2-3b-instruct-v1:0',
    'us.meta.llama3-3-70b-instruct-v1:0',
    'mistral.mistral-7b-instruct-v0:2',
    'mistral.mixtral-8x7b-instruct-v0:1',
    'mistral.mistral-large-2402-v1:0',
    'mistral.mistral-large-2407-v1:0',
]
"""Latest Bedrock models."""

BedrockModelName = Union[str, LatestBedrockModelNames]
"""Possible Bedrock model names.

Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints.
See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for a full list.
"""


# Generic type variables; `T` parameterizes `_AsyncIteratorWrapper` below.
# NOTE(review): `P` does not appear to be referenced anywhere in this module — confirm before removing.
P = ParamSpec('P')
T = typing.TypeVar('T')

113 

114 

class BedrockModelSettings(ModelSettings):
    """Settings for Bedrock models.

    ALL FIELDS MUST BE `bedrock_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    # No Bedrock-specific settings exist yet: this subclass is a placeholder so that
    # future `bedrock_`-prefixed options can be added without changing caller code.

120 

121 

@dataclass(init=False)
class BedrockConverseModel(Model):
    """A model that uses the Bedrock Converse API."""

    client: BedrockRuntimeClient

    _model_name: BedrockModelName = field(repr=False)
    _system: str = field(default='bedrock', repr=False)

    @property
    def model_name(self) -> str:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider, ex: openai."""
        return self._system

    def __init__(
        self,
        model_name: BedrockModelName,
        *,
        provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
    ):
        """Initialize a Bedrock model.

        Args:
            model_name: The name of the Bedrock model to use. List of model names available
                [here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
            provider: The provider to use for authentication and API access. Can be either the string
                'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
                created using the other parameters.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = cast('BedrockRuntimeClient', provider.client)

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
        """Collect function tools and result tools as Bedrock `toolSpec` entries."""
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
        """Convert a pydantic-ai `ToolDefinition` into a Bedrock `toolSpec` dict."""
        return {
            'toolSpec': {
                'name': f.name,
                'description': f.description,
                'inputSchema': {'json': f.parameters_json_schema},
            }
        }

    @property
    def base_url(self) -> str:
        """The endpoint URL of the underlying boto3 client."""
        return str(self.client.meta.endpoint_url)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, result.Usage]:
        """Make a non-streaming request to the model and return the response with usage."""
        response = await self._messages_create(messages, False, model_settings, model_request_parameters)
        return await self._process_response(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request to the model, yielding a `BedrockStreamedResponse`."""
        response = await self._messages_create(messages, True, model_settings, model_request_parameters)
        yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)

    async def _process_response(self, response: ConverseResponseTypeDef) -> tuple[ModelResponse, result.Usage]:
        """Translate a Bedrock Converse response into a `ModelResponse` and `Usage`.

        Each content block is either a text block or a tool-use block; anything else
        is unexpected and trips the assertion below.
        """
        items: list[ModelResponsePart] = []
        if message := response['output'].get('message'):
            for item in message['content']:
                if text := item.get('text'):
                    items.append(TextPart(content=text))
                else:
                    tool_use = item.get('toolUse')
                    assert tool_use is not None, f'Found a content that is not a text or tool use: {item}'
                    items.append(
                        ToolCallPart(
                            tool_name=tool_use['name'],
                            args=tool_use['input'],
                            tool_call_id=tool_use['toolUseId'],
                        ),
                    )
        usage = result.Usage(
            request_tokens=response['usage']['inputTokens'],
            response_tokens=response['usage']['outputTokens'],
            total_tokens=response['usage']['totalTokens'],
        )
        return ModelResponse(items, model_name=self.model_name), usage

    @overload
    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> EventStream[ConverseStreamOutputTypeDef]:
        pass

    @overload
    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ConverseResponseTypeDef:
        pass

    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
        """Build the Converse request and invoke `converse` / `converse_stream`.

        The synchronous boto3 calls are run in a worker thread via
        `anyio.to_thread.run_sync` so the event loop is not blocked.
        """
        tools = self._get_tools(model_request_parameters)
        # Only Anthropic models on Bedrock accept an explicit toolChoice.
        support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
        if not tools or not support_tools_choice:
            tool_choice: ToolChoiceTypeDef = {}
        elif not model_request_parameters.allow_text_result:
            # The model must call a tool; plain-text answers are not acceptable.
            tool_choice = {'any': {}}
        else:
            tool_choice = {'auto': {}}

        system_prompt, bedrock_messages = await self._map_message(messages)
        inference_config = self._map_inference_config(model_settings)

        params = {
            'modelId': self.model_name,
            'messages': bedrock_messages,
            'system': [{'text': system_prompt}],
            'inferenceConfig': inference_config,
            **(
                {'toolConfig': {'tools': tools, **({'toolChoice': tool_choice} if tool_choice else {})}}
                if tools
                else {}
            ),
        }

        if stream:
            model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
            model_response = model_response['stream']
        else:
            model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
        return model_response

    @staticmethod
    def _map_inference_config(
        model_settings: ModelSettings | None,
    ) -> InferenceConfigurationTypeDef:
        """Map pydantic-ai `ModelSettings` onto Bedrock's `inferenceConfig` dict.

        Uses `is not None` checks (not truthiness) so that legitimate zero values,
        e.g. `temperature=0` or `top_p=0`, are forwarded instead of silently dropped.
        """
        model_settings = model_settings or {}
        inference_config: InferenceConfigurationTypeDef = {}

        if (max_tokens := model_settings.get('max_tokens')) is not None:
            inference_config['maxTokens'] = max_tokens
        if (temperature := model_settings.get('temperature')) is not None:
            inference_config['temperature'] = temperature
        if (top_p := model_settings.get('top_p')) is not None:
            inference_config['topP'] = top_p
        # TODO(Marcelo): This is not included in model_settings yet.
        # if stop_sequences := model_settings.get('stop_sequences'):
        #     inference_config['stopSequences'] = stop_sequences

        return inference_config

    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageUnionTypeDef]]:
        """Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`.

        Returns the concatenated system prompt text and the list of Bedrock messages.
        Tool results and retry prompts become `toolResult` blocks on `user` messages;
        model output becomes an `assistant` message of text / `toolUse` blocks.
        """
        # Multiple SystemPromptParts are concatenated with no separator.
        system_prompt: str = ''
        bedrock_messages: list[MessageUnionTypeDef] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        system_prompt += part.content
                    elif isinstance(part, UserPromptPart):
                        bedrock_messages.extend(await self._map_user_prompt(part))
                    elif isinstance(part, ToolReturnPart):
                        assert part.tool_call_id is not None
                        bedrock_messages.append(
                            {
                                'role': 'user',
                                'content': [
                                    {
                                        'toolResult': {
                                            'toolUseId': part.tool_call_id,
                                            'content': [{'text': part.model_response_str()}],
                                            'status': 'success',
                                        }
                                    }
                                ],
                            }
                        )
                    elif isinstance(part, RetryPromptPart):
                        # TODO(Marcelo): We need to add a test here.
                        if part.tool_name is None:  # pragma: no cover
                            bedrock_messages.append({'role': 'user', 'content': [{'text': part.model_response()}]})
                        else:
                            assert part.tool_call_id is not None
                            bedrock_messages.append(
                                {
                                    'role': 'user',
                                    'content': [
                                        {
                                            'toolResult': {
                                                'toolUseId': part.tool_call_id,
                                                'content': [{'text': part.model_response()}],
                                                'status': 'error',
                                            }
                                        }
                                    ],
                                }
                            )
            elif isinstance(m, ModelResponse):
                content: list[ContentBlockOutputTypeDef] = []
                for item in m.parts:
                    if isinstance(item, TextPart):
                        content.append({'text': item.content})
                    else:
                        assert isinstance(item, ToolCallPart)
                        content.append(self._map_tool_call(item))
                bedrock_messages.append({'role': 'assistant', 'content': content})
            else:
                assert_never(m)
        return system_prompt, bedrock_messages

    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> list[MessageUnionTypeDef]:
        """Map a `UserPromptPart` to a single Bedrock `user` message.

        URL-based images/documents are downloaded here, since the Converse API
        expects inline bytes. Documents are auto-named `Document N` because Bedrock
        requires a name for document blocks.
        """
        content: list[ContentBlockUnionTypeDef] = []
        if isinstance(part.content, str):
            content.append({'text': part.content})
        else:
            document_count = 0
            for item in part.content:
                if isinstance(item, str):
                    content.append({'text': item})
                elif isinstance(item, BinaryContent):
                    format = item.format
                    if item.is_document:
                        document_count += 1
                        name = f'Document {document_count}'
                        assert format in ('pdf', 'txt', 'csv', 'doc', 'docx', 'xls', 'xlsx', 'html', 'md')
                        content.append({'document': {'name': name, 'format': format, 'source': {'bytes': item.data}}})
                    elif item.is_image:
                        assert format in ('jpeg', 'png', 'gif', 'webp')
                        content.append({'image': {'format': format, 'source': {'bytes': item.data}}})
                    else:
                        raise NotImplementedError('Binary content is not supported yet.')
                elif isinstance(item, (ImageUrl, DocumentUrl)):
                    response = await cached_async_http_client().get(item.url)
                    response.raise_for_status()
                    if item.kind == 'image-url':
                        format = item.media_type.split('/')[1]
                        assert format in ('jpeg', 'png', 'gif', 'webp'), f'Unsupported image format: {format}'
                        image: ImageBlockTypeDef = {'format': format, 'source': {'bytes': response.content}}
                        content.append({'image': image})
                    elif item.kind == 'document-url':
                        document_count += 1
                        name = f'Document {document_count}'
                        data = response.content
                        content.append({'document': {'name': name, 'format': item.format, 'source': {'bytes': data}}})
                elif isinstance(item, AudioUrl):  # pragma: no cover
                    raise NotImplementedError('Audio is not supported yet.')
                else:
                    assert_never(item)
        return [{'role': 'user', 'content': content}]

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
        """Convert a `ToolCallPart` into a Bedrock `toolUse` content block."""
        return {
            'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
        }

408 

409 

@dataclass
class BedrockStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Bedrock models."""

    # Name of the model that produced this stream.
    _model_name: BedrockModelName
    # The raw (synchronous) boto3 event stream returned by `converse_stream`.
    _event_stream: EventStream[ConverseStreamOutputTypeDef]
    # Timestamp captured when this response object is created.
    _timestamp: datetime = field(default_factory=_utils.now_utc)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.

        This method should be implemented by subclasses to translate the vendor-specific stream of events into
        pydantic_ai-format events.
        """
        chunk: ConverseStreamOutputTypeDef
        # Carries the tool-use id from a `contentBlockStart` event to the
        # `contentBlockDelta` events that follow it.
        tool_id: str | None = None
        # boto3's EventStream is synchronous, so wrap it to iterate in a worker thread.
        async for chunk in _AsyncIteratorWrapper(self._event_stream):
            # TODO(Marcelo): Switch this to `match` when we drop Python 3.9 support.
            if 'messageStart' in chunk:
                continue
            if 'messageStop' in chunk:
                continue
            if 'metadata' in chunk:
                # Metadata events carry token-usage accounting, not content.
                if 'usage' in chunk['metadata']:
                    self._usage += self._map_usage(chunk['metadata'])
                continue
            if 'contentBlockStart' in chunk:
                index = chunk['contentBlockStart']['contentBlockIndex']
                start = chunk['contentBlockStart']['start']
                if 'toolUse' in start:
                    # A tool call is starting: record its id/name; the arguments
                    # arrive incrementally in subsequent delta events.
                    tool_use_start = start['toolUse']
                    tool_id = tool_use_start['toolUseId']
                    tool_name = tool_use_start['name']
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=index,
                        tool_name=tool_name,
                        args=None,
                        tool_call_id=tool_id,
                    )
                    if maybe_event:
                        yield maybe_event
            if 'contentBlockDelta' in chunk:
                index = chunk['contentBlockDelta']['contentBlockIndex']
                delta = chunk['contentBlockDelta']['delta']
                if 'text' in delta:
                    yield self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
                if 'toolUse' in delta:
                    # Incremental tool-call arguments; `tool_id` was set by the
                    # matching `contentBlockStart` event above.
                    tool_use = delta['toolUse']
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=index,
                        tool_name=tool_use.get('name'),
                        args=tool_use.get('input'),
                        tool_call_id=tool_id,
                    )
                    if maybe_event:
                        yield maybe_event

    @property
    def timestamp(self) -> datetime:
        """Timestamp of when this streamed response was created."""
        return self._timestamp

    @property
    def model_name(self) -> str:
        """Get the model name of the response."""
        return self._model_name

    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> result.Usage:
        """Convert a Bedrock stream-metadata event's usage block into a `result.Usage`."""
        return result.Usage(
            request_tokens=metadata['usage']['inputTokens'],
            response_tokens=metadata['usage']['outputTokens'],
            total_tokens=metadata['usage']['totalTokens'],
        )

482 

483 

class _AsyncIteratorWrapper(Generic[T]):
    """Wrap a synchronous iterator in an async iterator.

    Each `__anext__` call advances the underlying iterator inside a worker
    thread, so blocking iteration never stalls the event loop.
    """

    def __init__(self, sync_iterator: Iterable[T]):
        self.sync_iterator = iter(sync_iterator)

    def __aiter__(self):
        return self

    async def __anext__(self) -> T:
        try:
            # Advance the synchronous iterator in a thread pool.
            value = await anyio.to_thread.run_sync(next, self.sync_iterator)
        except RuntimeError as e:
            # anyio surfaces a StopIteration raised inside the worker thread as a
            # RuntimeError with the StopIteration attached as its __cause__.
            if type(e.__cause__) is StopIteration:
                raise StopAsyncIteration
            raise e
        else:
            return value