Coverage for pydantic_ai_slim/pydantic_ai/models/anthropic.py: 94.59%
210 statements
« prev ^ index » next — coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3import base64
4import io
5from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
6from contextlib import asynccontextmanager
7from dataclasses import dataclass, field
8from datetime import datetime, timezone
9from json import JSONDecodeError, loads as json_loads
10from typing import Any, Literal, Union, cast, overload
12from anthropic.types import DocumentBlockParam
13from typing_extensions import assert_never
15from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
16from .._utils import guard_tool_call_id as _guard_tool_call_id
17from ..messages import (
18 BinaryContent,
19 DocumentUrl,
20 ImageUrl,
21 ModelMessage,
22 ModelRequest,
23 ModelResponse,
24 ModelResponsePart,
25 ModelResponseStreamEvent,
26 RetryPromptPart,
27 SystemPromptPart,
28 TextPart,
29 ToolCallPart,
30 ToolReturnPart,
31 UserPromptPart,
32)
33from ..providers import Provider, infer_provider
34from ..settings import ModelSettings
35from ..tools import ToolDefinition
36from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests
38try:
39 from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
40 from anthropic.types import (
41 Base64PDFSourceParam,
42 ContentBlock,
43 ImageBlockParam,
44 Message as AnthropicMessage,
45 MessageParam,
46 MetadataParam,
47 PlainTextSourceParam,
48 RawContentBlockDeltaEvent,
49 RawContentBlockStartEvent,
50 RawContentBlockStopEvent,
51 RawMessageDeltaEvent,
52 RawMessageStartEvent,
53 RawMessageStopEvent,
54 RawMessageStreamEvent,
55 TextBlock,
56 TextBlockParam,
57 TextDelta,
58 ToolChoiceParam,
59 ToolParam,
60 ToolResultBlockParam,
61 ToolUseBlock,
62 ToolUseBlockParam,
63 )
64except ImportError as _import_error:
65 raise ImportError(
66 'Please install `anthropic` to use the Anthropic model, '
67 'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
68 ) from _import_error
LatestAnthropicModelNames = Literal[
    'claude-3-7-sonnet-latest',
    'claude-3-5-haiku-latest',
    'claude-3-5-sonnet-latest',
    'claude-3-opus-latest',
]
"""Latest Anthropic models."""

AnthropicModelName = Union[str, LatestAnthropicModelNames]
"""Possible Anthropic model names.

Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
"""
class AnthropicModelSettings(ModelSettings):
    """Settings used for an Anthropic model request.

    ALL FIELDS MUST BE `anthropic_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    anthropic_metadata: MetadataParam
    """An object describing metadata about the request.

    Contains `user_id`, an external identifier for the user who is associated with the request."""
@dataclass(init=False)
class AnthropicModel(Model):
    """A model that uses the Anthropic API.

    Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python)
    to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    Both regular and streamed requests are supported (see `request` and `request_stream`).
    """

    # The Anthropic SDK client used for all API calls; excluded from repr.
    client: AsyncAnthropic = field(repr=False)

    _model_name: AnthropicModelName = field(repr=False)
    _system: str = field(default='anthropic', repr=False)

    def __init__(
        self,
        model_name: AnthropicModelName,
        *,
        provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
    ):
        """Initialize an Anthropic model.

        Args:
            model_name: The name of the Anthropic model to use. List of model names available
                [here](https://docs.anthropic.com/en/docs/about-claude/models).
            provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
                instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
        """
        self._model_name = model_name

        # A string provider name is resolved to a concrete Provider instance.
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client

    @property
    def base_url(self) -> str:
        """The base URL of the underlying Anthropic client."""
        return str(self.client.base_url)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, usage.Usage]:
        """Make a non-streaming request and return the response with its usage."""
        check_allow_model_requests()
        response = await self._messages_create(
            messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
        )
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request, yielding a `StreamedResponse` while the stream is open."""
        check_allow_model_requests()
        response = await self._messages_create(
            messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
        )
        # The SDK stream is an async context manager; keep it open for the caller's lifetime.
        async with response:
            yield await self._process_streamed_response(response)

    @property
    def model_name(self) -> AnthropicModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    @overload
    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: AnthropicModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncStream[RawMessageStreamEvent]:
        pass

    @overload
    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: AnthropicModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AnthropicMessage:
        pass

    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: AnthropicModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
        # standalone function to make it easier to override
        tools = self._get_tools(model_request_parameters)
        tool_choice: ToolChoiceParam | None

        if not tools:
            tool_choice = None
        else:
            # 'any' forces a tool call when plain-text results aren't allowed; 'auto' lets the model decide.
            if not model_request_parameters.allow_text_result:
                tool_choice = {'type': 'any'}
            else:
                tool_choice = {'type': 'auto'}

            if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
                tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls

        system_prompt, anthropic_messages = await self._map_message(messages)

        try:
            return await self.client.messages.create(
                # max_tokens is required by the Anthropic API; 1024 is this model's default.
                max_tokens=model_settings.get('max_tokens', 1024),
                system=system_prompt or NOT_GIVEN,
                messages=anthropic_messages,
                model=self._model_name,
                tools=tools or NOT_GIVEN,
                tool_choice=tool_choice or NOT_GIVEN,
                stream=stream,
                temperature=model_settings.get('temperature', NOT_GIVEN),
                top_p=model_settings.get('top_p', NOT_GIVEN),
                timeout=model_settings.get('timeout', NOT_GIVEN),
                metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
            )
        except APIStatusError as e:
            # Wrap HTTP-level errors in the provider-agnostic ModelHTTPError.
            if (status_code := e.status_code) >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise

    def _process_response(self, response: AnthropicMessage) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        items: list[ModelResponsePart] = []
        for item in response.content:
            if isinstance(item, TextBlock):
                items.append(TextPart(content=item.text))
            else:
                # Only text and tool-use blocks are expected in a response here.
                assert isinstance(item, ToolUseBlock), 'unexpected item type'
                items.append(
                    ToolCallPart(
                        tool_name=item.name,
                        args=cast(dict[str, Any], item.input),
                        tool_call_id=item.id,
                    )
                )

        return ModelResponse(items, model_name=response.model)

    async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
        """Wrap a raw SDK stream in an `AnthropicStreamedResponse`, failing fast on an empty stream."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
        timestamp = datetime.now(tz=timezone.utc)
        return AnthropicStreamedResponse(
            _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp
        )

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolParam]:
        """Map function tools (and result tools, if any) to Anthropic `ToolParam`s."""
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
        """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
        # NOTE: consecutive system prompt parts are concatenated with no separator.
        system_prompt: str = ''
        anthropic_messages: list[MessageParam] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                user_content_params: list[
                    ToolResultBlockParam | TextBlockParam | ImageBlockParam | DocumentBlockParam
                ] = []
                for request_part in m.parts:
                    if isinstance(request_part, SystemPromptPart):
                        system_prompt += request_part.content
                    elif isinstance(request_part, UserPromptPart):
                        async for content in self._map_user_prompt(request_part):
                            user_content_params.append(content)
                    elif isinstance(request_part, ToolReturnPart):
                        tool_result_block_param = ToolResultBlockParam(
                            tool_use_id=_guard_tool_call_id(t=request_part),
                            type='tool_result',
                            content=request_part.model_response_str(),
                            is_error=False,
                        )
                        user_content_params.append(tool_result_block_param)
                    elif isinstance(request_part, RetryPromptPart):
                        # A retry with no tool name is a plain-text retry message; otherwise it's
                        # reported back to the model as an errored tool result.
                        if request_part.tool_name is None:
                            retry_param = TextBlockParam(type='text', text=request_part.model_response())
                        else:
                            retry_param = ToolResultBlockParam(
                                tool_use_id=_guard_tool_call_id(t=request_part),
                                type='tool_result',
                                content=request_part.model_response(),
                                is_error=True,
                            )
                        user_content_params.append(retry_param)
                anthropic_messages.append(MessageParam(role='user', content=user_content_params))
            elif isinstance(m, ModelResponse):
                assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
                for response_part in m.parts:
                    if isinstance(response_part, TextPart):
                        assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
                    else:
                        tool_use_block_param = ToolUseBlockParam(
                            id=_guard_tool_call_id(t=response_part),
                            type='tool_use',
                            name=response_part.tool_name,
                            input=response_part.args_as_dict(),
                        )
                        assistant_content_params.append(tool_use_block_param)
                anthropic_messages.append(MessageParam(role='assistant', content=assistant_content_params))
            else:
                assert_never(m)
        return system_prompt, anthropic_messages

    @staticmethod
    async def _map_user_prompt(
        part: UserPromptPart,
    ) -> AsyncGenerator[ImageBlockParam | TextBlockParam | DocumentBlockParam]:
        """Yield Anthropic content-block params for each item of a user prompt.

        URLs are downloaded eagerly here so the content can be sent inline as base64.
        """
        if isinstance(part.content, str):
            yield TextBlockParam(text=part.content, type='text')
        else:
            for item in part.content:
                if isinstance(item, str):
                    yield TextBlockParam(text=item, type='text')
                elif isinstance(item, BinaryContent):
                    if item.is_image:
                        # NOTE(review): a BytesIO is passed where the SDK types declare a base64 str
                        # (hence the type: ignore) — presumably the SDK encodes file-like objects; confirm.
                        yield ImageBlockParam(
                            source={'data': io.BytesIO(item.data), 'media_type': item.media_type, 'type': 'base64'},  # type: ignore
                            type='image',
                        )
                    elif item.media_type == 'application/pdf':
                        yield DocumentBlockParam(
                            source=Base64PDFSourceParam(
                                data=io.BytesIO(item.data),
                                media_type='application/pdf',
                                type='base64',
                            ),
                            type='document',
                        )
                    else:
                        raise RuntimeError('Only images and PDFs are supported for binary content')
                elif isinstance(item, ImageUrl):
                    try:
                        # Fast path: item.media_type is derivable from the URL.
                        response = await cached_async_http_client().get(item.url)
                        response.raise_for_status()
                        yield ImageBlockParam(
                            source={
                                'data': io.BytesIO(response.content),
                                'media_type': item.media_type,
                                'type': 'base64',
                            },
                            type='image',
                        )
                    except ValueError:
                        # Download the file if can't find the mime type.
                        client = cached_async_http_client()
                        response = await client.get(item.url, follow_redirects=True)
                        response.raise_for_status()
                        base64_encoded = base64.b64encode(response.content).decode('utf-8')
                        if (mime_type := response.headers['Content-Type']) in (
                            'image/jpeg',
                            'image/png',
                            'image/gif',
                            'image/webp',
                        ):
                            yield ImageBlockParam(
                                source={'data': base64_encoded, 'media_type': mime_type, 'type': 'base64'},
                                type='image',
                            )
                        else:  # pragma: no cover
                            raise RuntimeError(f'Unsupported image type: {mime_type}')
                elif isinstance(item, DocumentUrl):
                    response = await cached_async_http_client().get(item.url)
                    response.raise_for_status()
                    if item.media_type == 'application/pdf':
                        yield DocumentBlockParam(
                            source=Base64PDFSourceParam(
                                data=io.BytesIO(response.content),
                                media_type=item.media_type,
                                type='base64',
                            ),
                            type='document',
                        )
                    elif item.media_type == 'text/plain':
                        yield DocumentBlockParam(
                            source=PlainTextSourceParam(data=response.text, media_type=item.media_type, type='text'),
                            type='document',
                        )
                    else:  # pragma: no cover
                        raise RuntimeError(f'Unsupported media type: {item.media_type}')
                else:
                    raise RuntimeError(f'Unsupported content type: {type(item)}')

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
        """Map a pydantic-ai `ToolDefinition` to an Anthropic `ToolParam` dict."""
        return {
            'name': f.name,
            'description': f.description,
            'input_schema': f.parameters_json_schema,
        }
def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
    """Translate usage data from a full Anthropic message or a stream event into `usage.Usage`.

    Events that carry no usage information (stop events, content-block events) map to an
    empty `usage.Usage`.
    """
    if isinstance(message, AnthropicMessage):
        response_usage = message.usage
    elif isinstance(message, RawMessageStartEvent):
        response_usage = message.message.usage
    elif isinstance(message, RawMessageDeltaEvent):
        response_usage = message.usage
    else:
        # No usage information provided in:
        # - RawMessageStopEvent
        # - RawContentBlockStartEvent
        # - RawContentBlockDeltaEvent
        # - RawContentBlockStopEvent
        response_usage = None

    if response_usage is None:
        return usage.Usage()

    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
    input_tokens = getattr(response_usage, 'input_tokens', None)
    output_tokens = response_usage.output_tokens

    return usage.Usage(
        request_tokens=input_tokens,
        response_tokens=output_tokens,
        total_tokens=(input_tokens or 0) + output_tokens,
    )
@dataclass
class AnthropicStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Anthropic models."""

    _model_name: AnthropicModelName
    # The raw (peekable) stream of SDK events being consumed.
    _response: AsyncIterable[RawMessageStreamEvent]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Translate raw Anthropic stream events into `ModelResponseStreamEvent`s.

        Accumulates usage from every event, and buffers partial tool-call JSON across
        delta events until it parses as valid JSON.
        """
        current_block: ContentBlock | None = None
        current_json: str = ''  # buffer for not-yet-valid partial tool-call JSON

        async for event in self._response:
            self._usage += _map_usage(event)

            if isinstance(event, RawContentBlockStartEvent):
                current_block = event.content_block
                if isinstance(current_block, TextBlock) and current_block.text:
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
                elif isinstance(current_block, ToolUseBlock):
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=current_block.id,
                        tool_name=current_block.name,
                        args=cast(dict[str, Any], current_block.input),
                        tool_call_id=current_block.id,
                    )
                    if maybe_event is not None:
                        yield maybe_event

            elif isinstance(event, RawContentBlockDeltaEvent):
                if isinstance(event.delta, TextDelta):
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
                elif (
                    current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
                ):
                    # Try to parse the JSON immediately, otherwise cache the value for later. This handles
                    # cases where the JSON is not currently valid but will be valid once we stream more tokens.
                    try:
                        parsed_args = json_loads(current_json + event.delta.partial_json)
                        current_json = ''
                    except JSONDecodeError:
                        current_json += event.delta.partial_json
                        continue

                    # For tool calls, we need to handle partial JSON updates
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=current_block.id,
                        tool_name='',
                        args=parsed_args,
                        tool_call_id=current_block.id,
                    )
                    if maybe_event is not None:
                        yield maybe_event

            elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
                # Block finished; clear state so stray deltas aren't attributed to it.
                current_block = None

    @property
    def model_name(self) -> AnthropicModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp