Coverage for pydantic_ai_slim/pydantic_ai/models/bedrock.py: 95.89%
214 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations
3import functools
4import typing
5from collections.abc import AsyncIterator, Iterable
6from contextlib import asynccontextmanager
7from dataclasses import dataclass, field
8from datetime import datetime
9from typing import TYPE_CHECKING, Generic, Literal, Union, cast, overload
11import anyio
12import anyio.to_thread
13from typing_extensions import ParamSpec, assert_never
15from pydantic_ai import _utils, result
16from pydantic_ai.messages import (
17 AudioUrl,
18 BinaryContent,
19 DocumentUrl,
20 ImageUrl,
21 ModelMessage,
22 ModelRequest,
23 ModelResponse,
24 ModelResponsePart,
25 ModelResponseStreamEvent,
26 RetryPromptPart,
27 SystemPromptPart,
28 TextPart,
29 ToolCallPart,
30 ToolReturnPart,
31 UserPromptPart,
32)
33from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client
34from pydantic_ai.providers import Provider, infer_provider
35from pydantic_ai.settings import ModelSettings
36from pydantic_ai.tools import ToolDefinition
38if TYPE_CHECKING:
39 from botocore.client import BaseClient
40 from botocore.eventstream import EventStream
41 from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
42 from mypy_boto3_bedrock_runtime.type_defs import (
43 ContentBlockOutputTypeDef,
44 ContentBlockUnionTypeDef,
45 ConverseResponseTypeDef,
46 ConverseStreamMetadataEventTypeDef,
47 ConverseStreamOutputTypeDef,
48 ImageBlockTypeDef,
49 InferenceConfigurationTypeDef,
50 MessageUnionTypeDef,
51 ToolChoiceTypeDef,
52 ToolTypeDef,
53 )
56LatestBedrockModelNames = Literal[
57 'amazon.titan-tg1-large',
58 'amazon.titan-text-lite-v1',
59 'amazon.titan-text-express-v1',
60 'us.amazon.nova-pro-v1:0',
61 'us.amazon.nova-lite-v1:0',
62 'us.amazon.nova-micro-v1:0',
63 'anthropic.claude-3-5-sonnet-20241022-v2:0',
64 'us.anthropic.claude-3-5-sonnet-20241022-v2:0',
65 'anthropic.claude-3-5-haiku-20241022-v1:0',
66 'us.anthropic.claude-3-5-haiku-20241022-v1:0',
67 'anthropic.claude-instant-v1',
68 'anthropic.claude-v2:1',
69 'anthropic.claude-v2',
70 'anthropic.claude-3-sonnet-20240229-v1:0',
71 'us.anthropic.claude-3-sonnet-20240229-v1:0',
72 'anthropic.claude-3-haiku-20240307-v1:0',
73 'us.anthropic.claude-3-haiku-20240307-v1:0',
74 'anthropic.claude-3-opus-20240229-v1:0',
75 'us.anthropic.claude-3-opus-20240229-v1:0',
76 'anthropic.claude-3-5-sonnet-20240620-v1:0',
77 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
78 'anthropic.claude-3-7-sonnet-20250219-v1:0',
79 'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
80 'cohere.command-text-v14',
81 'cohere.command-r-v1:0',
82 'cohere.command-r-plus-v1:0',
83 'cohere.command-light-text-v14',
84 'meta.llama3-8b-instruct-v1:0',
85 'meta.llama3-70b-instruct-v1:0',
86 'meta.llama3-1-8b-instruct-v1:0',
87 'us.meta.llama3-1-8b-instruct-v1:0',
88 'meta.llama3-1-70b-instruct-v1:0',
89 'us.meta.llama3-1-70b-instruct-v1:0',
90 'meta.llama3-1-405b-instruct-v1:0',
91 'us.meta.llama3-2-11b-instruct-v1:0',
92 'us.meta.llama3-2-90b-instruct-v1:0',
93 'us.meta.llama3-2-1b-instruct-v1:0',
94 'us.meta.llama3-2-3b-instruct-v1:0',
95 'us.meta.llama3-3-70b-instruct-v1:0',
96 'mistral.mistral-7b-instruct-v0:2',
97 'mistral.mixtral-8x7b-instruct-v0:1',
98 'mistral.mistral-large-2402-v1:0',
99 'mistral.mistral-large-2407-v1:0',
100]
101"""Latest Bedrock models."""
103BedrockModelName = Union[str, LatestBedrockModelNames]
104"""Possible Bedrock model names.
106Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints.
107See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for a full list.
108"""
111P = ParamSpec('P')
112T = typing.TypeVar('T')
class BedrockModelSettings(ModelSettings):
    """Settings for Bedrock models.

    All fields must be prefixed with `bedrock_` so they can be merged with the settings of other models.
    """
@dataclass(init=False)
class BedrockConverseModel(Model):
    """A model that uses the Bedrock Converse API.

    Translates between pydantic_ai messages and the Bedrock Converse
    request/response shapes, delegating the blocking boto3 calls to a
    worker thread so the event loop is never blocked.
    """

    client: BedrockRuntimeClient

    _model_name: BedrockModelName = field(repr=False)
    _system: str = field(default='bedrock', repr=False)

    @property
    def model_name(self) -> str:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider, ex: openai."""
        return self._system

    def __init__(
        self,
        model_name: BedrockModelName,
        *,
        provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
    ):
        """Initialize a Bedrock model.

        Args:
            model_name: The name of the Bedrock model to use. List of model names available
                [here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
            provider: The provider to use for authentication and API access. Can be either the string
                'bedrock' or an instance of `Provider[BaseClient]`. If not provided, a new provider will be
                created using the other parameters.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = cast('BedrockRuntimeClient', provider.client)

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
        """Collect function tools and result tools as Bedrock `toolSpec` entries."""
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
        """Convert a pydantic_ai `ToolDefinition` into a Bedrock tool spec."""
        return {
            'toolSpec': {
                'name': f.name,
                'description': f.description,
                'inputSchema': {'json': f.parameters_json_schema},
            }
        }

    @property
    def base_url(self) -> str:
        """The endpoint URL of the wrapped Bedrock runtime client."""
        return str(self.client.meta.endpoint_url)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, result.Usage]:
        """Make a non-streaming request to the model and return the response and usage."""
        response = await self._messages_create(messages, False, model_settings, model_request_parameters)
        return await self._process_response(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request to the model, yielding a `StreamedResponse`."""
        response = await self._messages_create(messages, True, model_settings, model_request_parameters)
        yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)

    async def _process_response(self, response: ConverseResponseTypeDef) -> tuple[ModelResponse, result.Usage]:
        """Map a Converse API response to a `ModelResponse` plus token usage."""
        items: list[ModelResponsePart] = []
        if message := response['output'].get('message'):
            for item in message['content']:
                # Use an explicit `is not None` check so that an empty text
                # block ('') is still treated as text, not as a tool use.
                if (text := item.get('text')) is not None:
                    items.append(TextPart(content=text))
                else:
                    tool_use = item.get('toolUse')
                    assert tool_use is not None, f'Found a content that is not a text or tool use: {item}'
                    items.append(
                        ToolCallPart(
                            tool_name=tool_use['name'],
                            args=tool_use['input'],
                            tool_call_id=tool_use['toolUseId'],
                        ),
                    )
        usage = result.Usage(
            request_tokens=response['usage']['inputTokens'],
            response_tokens=response['usage']['outputTokens'],
            total_tokens=response['usage']['totalTokens'],
        )
        return ModelResponse(items, model_name=self.model_name), usage

    @overload
    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> EventStream[ConverseStreamOutputTypeDef]:
        pass

    @overload
    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ConverseResponseTypeDef:
        pass

    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
        """Build the Converse request and call `converse` or `converse_stream`.

        The boto3 call is synchronous, so it is run with `anyio.to_thread` to
        avoid blocking the event loop.
        """
        tools = self._get_tools(model_request_parameters)
        # Only Anthropic models on Bedrock support an explicit toolChoice.
        support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
        if not tools or not support_tools_choice:
            tool_choice: ToolChoiceTypeDef = {}
        elif not model_request_parameters.allow_text_result:
            # Text results are not allowed, so force a tool call.
            tool_choice = {'any': {}}
        else:
            tool_choice = {'auto': {}}

        system_prompt, bedrock_messages = await self._map_message(messages)
        inference_config = self._map_inference_config(model_settings)

        params = {
            'modelId': self.model_name,
            'messages': bedrock_messages,
            'system': [{'text': system_prompt}],
            'inferenceConfig': inference_config,
            **(
                {'toolConfig': {'tools': tools, **({'toolChoice': tool_choice} if tool_choice else {})}}
                if tools
                else {}
            ),
        }

        if stream:
            model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
            model_response = model_response['stream']
        else:
            model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
        return model_response

    @staticmethod
    def _map_inference_config(
        model_settings: ModelSettings | None,
    ) -> InferenceConfigurationTypeDef:
        """Translate generic `ModelSettings` into Bedrock's `inferenceConfig`.

        Uses explicit `is not None` checks so that valid falsy values
        (e.g. `temperature=0.0` or `top_p=0`) are still forwarded instead of
        being silently dropped.
        """
        model_settings = model_settings or {}
        inference_config: InferenceConfigurationTypeDef = {}

        if (max_tokens := model_settings.get('max_tokens')) is not None:
            inference_config['maxTokens'] = max_tokens
        if (temperature := model_settings.get('temperature')) is not None:
            inference_config['temperature'] = temperature
        if (top_p := model_settings.get('top_p')) is not None:
            inference_config['topP'] = top_p
        # TODO(Marcelo): This is not included in model_settings yet.
        # if stop_sequences := model_settings.get('stop_sequences'):
        #     inference_config['stopSequences'] = stop_sequences

        return inference_config

    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageUnionTypeDef]]:
        """Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`.

        System prompts are concatenated into a single string (Bedrock takes the
        system prompt separately from the message list); everything else becomes
        Converse `user`/`assistant` messages.
        """
        system_prompt: str = ''
        bedrock_messages: list[MessageUnionTypeDef] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        system_prompt += part.content
                    elif isinstance(part, UserPromptPart):
                        bedrock_messages.extend(await self._map_user_prompt(part))
                    elif isinstance(part, ToolReturnPart):
                        assert part.tool_call_id is not None
                        bedrock_messages.append(
                            {
                                'role': 'user',
                                'content': [
                                    {
                                        'toolResult': {
                                            'toolUseId': part.tool_call_id,
                                            'content': [{'text': part.model_response_str()}],
                                            'status': 'success',
                                        }
                                    }
                                ],
                            }
                        )
                    elif isinstance(part, RetryPromptPart):
                        # TODO(Marcelo): We need to add a test here.
                        if part.tool_name is None:  # pragma: no cover
                            bedrock_messages.append({'role': 'user', 'content': [{'text': part.model_response()}]})
                        else:
                            assert part.tool_call_id is not None
                            bedrock_messages.append(
                                {
                                    'role': 'user',
                                    'content': [
                                        {
                                            'toolResult': {
                                                'toolUseId': part.tool_call_id,
                                                'content': [{'text': part.model_response()}],
                                                'status': 'error',
                                            }
                                        }
                                    ],
                                }
                            )
            elif isinstance(m, ModelResponse):
                content: list[ContentBlockOutputTypeDef] = []
                for item in m.parts:
                    if isinstance(item, TextPart):
                        content.append({'text': item.content})
                    else:
                        assert isinstance(item, ToolCallPart)
                        content.append(self._map_tool_call(item))
                bedrock_messages.append({'role': 'assistant', 'content': content})
            else:
                assert_never(m)
        return system_prompt, bedrock_messages

    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> list[MessageUnionTypeDef]:
        """Map a `UserPromptPart` (text and/or attachments) to Bedrock content blocks.

        URLs are downloaded here because the Converse API only accepts inline
        bytes for images and documents.
        """
        content: list[ContentBlockUnionTypeDef] = []
        if isinstance(part.content, str):
            content.append({'text': part.content})
        else:
            document_count = 0
            for item in part.content:
                if isinstance(item, str):
                    content.append({'text': item})
                elif isinstance(item, BinaryContent):
                    format = item.format
                    if item.is_document:
                        document_count += 1
                        name = f'Document {document_count}'
                        assert format in ('pdf', 'txt', 'csv', 'doc', 'docx', 'xls', 'xlsx', 'html', 'md')
                        content.append({'document': {'name': name, 'format': format, 'source': {'bytes': item.data}}})
                    elif item.is_image:
                        assert format in ('jpeg', 'png', 'gif', 'webp')
                        content.append({'image': {'format': format, 'source': {'bytes': item.data}}})
                    else:
                        raise NotImplementedError('Binary content is not supported yet.')
                elif isinstance(item, (ImageUrl, DocumentUrl)):
                    response = await cached_async_http_client().get(item.url)
                    response.raise_for_status()
                    if item.kind == 'image-url':
                        format = item.media_type.split('/')[1]
                        assert format in ('jpeg', 'png', 'gif', 'webp'), f'Unsupported image format: {format}'
                        image: ImageBlockTypeDef = {'format': format, 'source': {'bytes': response.content}}
                        content.append({'image': image})
                    elif item.kind == 'document-url':
                        document_count += 1
                        name = f'Document {document_count}'
                        data = response.content
                        content.append({'document': {'name': name, 'format': item.format, 'source': {'bytes': data}}})
                elif isinstance(item, AudioUrl):  # pragma: no cover
                    raise NotImplementedError('Audio is not supported yet.')
                else:
                    assert_never(item)
        return [{'role': 'user', 'content': content}]

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
        """Convert a `ToolCallPart` into a Bedrock `toolUse` content block."""
        return {
            'toolUse': {'toolUseId': _utils.guard_tool_call_id(t=t), 'name': t.tool_name, 'input': t.args_as_dict()}
        }
@dataclass
class BedrockStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Bedrock models."""

    # Bedrock model id that produced this stream.
    _model_name: BedrockModelName
    # Raw (synchronous) botocore event stream returned by `converse_stream`.
    _event_stream: EventStream[ConverseStreamOutputTypeDef]
    # Timestamp captured when this response object is created.
    _timestamp: datetime = field(default_factory=_utils.now_utc)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.

        This method should be implemented by subclasses to translate the vendor-specific stream of events into
        pydantic_ai-format events.
        """
        chunk: ConverseStreamOutputTypeDef
        tool_id: str | None = None
        # Wrap the blocking botocore event stream so each `next()` runs in a
        # worker thread instead of blocking the event loop.
        async for chunk in _AsyncIteratorWrapper(self._event_stream):
            # TODO(Marcelo): Switch this to `match` when we drop Python 3.9 support.
            if 'messageStart' in chunk:
                continue
            if 'messageStop' in chunk:
                continue
            if 'metadata' in chunk:
                # Token accounting arrives in a metadata chunk; fold it into usage.
                if 'usage' in chunk['metadata']:
                    self._usage += self._map_usage(chunk['metadata'])
                continue
            if 'contentBlockStart' in chunk:
                index = chunk['contentBlockStart']['contentBlockIndex']
                start = chunk['contentBlockStart']['start']
                if 'toolUse' in start:
                    tool_use_start = start['toolUse']
                    tool_id = tool_use_start['toolUseId']
                    tool_name = tool_use_start['name']
                    # Register the tool call now; its args stream in later via
                    # contentBlockDelta chunks keyed on the same block index.
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=index,
                        tool_name=tool_name,
                        args=None,
                        tool_call_id=tool_id,
                    )
                    if maybe_event:
                        yield maybe_event
            if 'contentBlockDelta' in chunk:
                index = chunk['contentBlockDelta']['contentBlockIndex']
                delta = chunk['contentBlockDelta']['delta']
                if 'text' in delta:
                    yield self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text'])
                if 'toolUse' in delta:
                    tool_use = delta['toolUse']
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=index,
                        tool_name=tool_use.get('name'),
                        args=tool_use.get('input'),
                        # NOTE(review): reuses the tool_id captured at the most recent
                        # contentBlockStart — presumably deltas always follow their
                        # start chunk; verify against the Converse stream contract.
                        tool_call_id=tool_id,
                    )
                    if maybe_event:
                        yield maybe_event

    @property
    def timestamp(self) -> datetime:
        # When this streamed response was created (UTC).
        return self._timestamp

    @property
    def model_name(self) -> str:
        """Get the model name of the response."""
        return self._model_name

    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> result.Usage:
        # Convert Bedrock's usage counters into a pydantic_ai `result.Usage`.
        return result.Usage(
            request_tokens=metadata['usage']['inputTokens'],
            response_tokens=metadata['usage']['outputTokens'],
            total_tokens=metadata['usage']['totalTokens'],
        )
484class _AsyncIteratorWrapper(Generic[T]):
485 """Wrap a synchronous iterator in an async iterator."""
487 def __init__(self, sync_iterator: Iterable[T]):
488 self.sync_iterator = iter(sync_iterator)
490 def __aiter__(self):
491 return self
493 async def __anext__(self) -> T:
494 try:
495 # Run the synchronous next() call in a thread pool
496 item = await anyio.to_thread.run_sync(next, self.sync_iterator)
497 return item
498 except RuntimeError as e:
499 if type(e.__cause__) is StopIteration: 499 ↛ 502line 499 didn't jump to line 502 because the condition on line 499 was always true
500 raise StopAsyncIteration
501 else:
502 raise e