Coverage for pydantic_ai_slim/pydantic_ai/models/anthropic.py: 94.35% (164 statements)

from __future__ import annotations as _annotations

from collections.abc import AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from json import JSONDecodeError, loads as json_loads
from typing import Any, Literal, Union, cast, overload

from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    StreamedResponse,
    cached_async_http_client,
    check_allow_model_requests,
)

try:
    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
    from anthropic.types import (
        Message as AnthropicMessage,
        MessageParam,
        MetadataParam,
        RawContentBlockDeltaEvent,
        RawContentBlockStartEvent,
        RawContentBlockStopEvent,
        RawMessageDeltaEvent,
        RawMessageStartEvent,
        RawMessageStopEvent,
        RawMessageStreamEvent,
        TextBlock,
        TextBlockParam,
        TextDelta,
        ToolChoiceParam,
        ToolParam,
        ToolResultBlockParam,
        ToolUseBlock,
        ToolUseBlockParam,
    )
except ImportError as _import_error:
    raise ImportError(
        'Please install `anthropic` to use the Anthropic model, '
        "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
    ) from _import_error

LatestAnthropicModelNames = Literal[
    'claude-3-5-haiku-latest',
    'claude-3-5-sonnet-latest',
    'claude-3-opus-latest',
]
"""Latest named Anthropic models."""

AnthropicModelName = Union[str, LatestAnthropicModelNames]
"""Possible Anthropic model names.

Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
"""


class AnthropicModelSettings(ModelSettings):
    """Settings used for an Anthropic model request."""

    anthropic_metadata: MetadataParam
    """An object describing metadata about the request.

    Contains `user_id`, an external identifier for the user who is associated with the request."""
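
# A minimal sketch of passing these settings on a run (assumes an `agent` built
# on an Anthropic model elsewhere; the `user_id` value is purely illustrative):
#
#     settings: AnthropicModelSettings = {'anthropic_metadata': {'user_id': 'user-123'}}
#     result = await agent.run('Hello', model_settings=settings)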


@dataclass(init=False)
class AnthropicModel(Model):
    """A model that uses the Anthropic API.

    Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: AnthropicModelName
    client: AsyncAnthropic = field(repr=False)

    def __init__(
        self,
        model_name: AnthropicModelName,
        *,
        api_key: str | None = None,
        anthropic_client: AsyncAnthropic | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize an Anthropic model.

        Args:
            model_name: The name of the Anthropic model to use. List of model names available
                [here](https://docs.anthropic.com/en/docs/about-claude/models).
            api_key: The API key to use for authentication. If not provided, the `ANTHROPIC_API_KEY` environment variable
                will be used if available.
            anthropic_client: An existing
                [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
                client to use. If provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name = model_name
        if anthropic_client is not None:
            assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
            self.client = anthropic_client
        elif http_client is not None:
            self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
        else:
            self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
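
    # A minimal construction sketch (assumes `ANTHROPIC_API_KEY` is set in the
    # environment; the `Agent` wiring is illustrative, not part of this module):
    #
    #     from pydantic_ai import Agent
    #     agent = Agent(AnthropicModel('claude-3-5-haiku-latest'))
    #     result = agent.run_sync('What is the capital of France?')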

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return AnthropicAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        return f'anthropic:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
        return {
            'name': f.name,
            'description': f.description,
            'input_schema': f.parameters_json_schema,
        }
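
    # For a hypothetical tool definition named 'get_weather', the mapping above
    # would produce a `ToolParam` shaped like (a sketch, not captured output):
    #
    #     {
    #         'name': 'get_weather',
    #         'description': 'Get the weather for a location.',
    #         'input_schema': {'type': 'object', 'properties': {...}},
    #     }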


@dataclass
class AnthropicAgentModel(AgentModel):
    """Implementation of `AgentModel` for Anthropic models."""

    client: AsyncAnthropic
    model_name: AnthropicModelName
    allow_text_result: bool
    tools: list[ToolParam]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
        async with response:
            yield await self._process_streamed_response(response)

    @overload
    async def _messages_create(
        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
    ) -> AsyncStream[RawMessageStreamEvent]:
        pass

    @overload
    async def _messages_create(
        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
    ) -> AnthropicMessage:
        pass

    async def _messages_create(
        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
    ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
        # standalone function to make it easier to override
        tool_choice: ToolChoiceParam | None

        if not self.tools:
            tool_choice = None
        else:
            if not self.allow_text_result:
                tool_choice = {'type': 'any'}
            else:
                tool_choice = {'type': 'auto'}

            if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
                tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
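
            # For example, with tools present, `allow_text_result=False`, and
            # `parallel_tool_calls=False` in the settings, this resolves to:
            #
            #     {'type': 'any', 'disable_parallel_tool_use': True}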

        system_prompt, anthropic_messages = self._map_message(messages)

        return await self.client.messages.create(
            max_tokens=model_settings.get('max_tokens', 1024),
            system=system_prompt or NOT_GIVEN,
            messages=anthropic_messages,
            model=self.model_name,
            tools=self.tools or NOT_GIVEN,
            tool_choice=tool_choice or NOT_GIVEN,
            stream=stream,
            temperature=model_settings.get('temperature', NOT_GIVEN),
            top_p=model_settings.get('top_p', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
        )

    def _process_response(self, response: AnthropicMessage) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        items: list[ModelResponsePart] = []
        for item in response.content:
            if isinstance(item, TextBlock):
                items.append(TextPart(content=item.text))
            else:
                assert isinstance(item, ToolUseBlock), 'unexpected item type'
                items.append(
                    ToolCallPart(
                        tool_name=item.name,
                        args=cast(dict[str, Any], item.input),
                        tool_call_id=item.id,
                    )
                )

        return ModelResponse(items, model_name=self.model_name)

    async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
        timestamp = datetime.now(tz=timezone.utc)
        return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)

    @staticmethod
    def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
        """Just maps a `pydantic_ai.Message` to an `anthropic.types.MessageParam`."""
        system_prompt: str = ''
        anthropic_messages: list[MessageParam] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        system_prompt += part.content
                    elif isinstance(part, UserPromptPart):
                        anthropic_messages.append(MessageParam(role='user', content=part.content))
                    elif isinstance(part, ToolReturnPart):
                        anthropic_messages.append(
                            MessageParam(
                                role='user',
                                content=[
                                    ToolResultBlockParam(
                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
                                        type='tool_result',
                                        content=part.model_response_str(),
                                        is_error=False,
                                    )
                                ],
                            )
                        )
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
                        else:
                            anthropic_messages.append(
                                MessageParam(
                                    role='user',
                                    content=[
                                        ToolResultBlockParam(
                                            tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
                                            type='tool_result',
                                            content=part.model_response(),
                                            is_error=True,
                                        ),
                                    ],
                                )
                            )
            elif isinstance(m, ModelResponse):
                content: list[TextBlockParam | ToolUseBlockParam] = []
                for item in m.parts:
                    if isinstance(item, TextPart):
                        content.append(TextBlockParam(text=item.content, type='text'))
                    else:
                        assert isinstance(item, ToolCallPart)
                        content.append(_map_tool_call(item))
                anthropic_messages.append(MessageParam(role='assistant', content=content))
            else:
                assert_never(m)
        return system_prompt, anthropic_messages
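
# For instance, a `ModelRequest` containing a single `UserPromptPart('hi')`
# maps to one Anthropic message (a sketch of the shape, not captured output):
#
#     {'role': 'user', 'content': 'hi'}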


def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
    return ToolUseBlockParam(
        id=_guard_tool_call_id(t=t, model_source='Anthropic'),
        type='tool_use',
        name=t.tool_name,
        input=t.args_as_dict(),
    )


def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
    if isinstance(message, AnthropicMessage):
        response_usage = message.usage
    else:
        if isinstance(message, RawMessageStartEvent):
            response_usage = message.message.usage
        elif isinstance(message, RawMessageDeltaEvent):
            response_usage = message.usage
        else:
            # No usage information provided in:
            # - RawMessageStopEvent
            # - RawContentBlockStartEvent
            # - RawContentBlockDeltaEvent
            # - RawContentBlockStopEvent
            response_usage = None

    if response_usage is None:
        return usage.Usage()

    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
    request_tokens = getattr(response_usage, 'input_tokens', None)

    return usage.Usage(
        request_tokens=request_tokens,
        response_tokens=response_usage.output_tokens,
        total_tokens=(request_tokens or 0) + response_usage.output_tokens,
    )


@dataclass
class AnthropicStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Anthropic models."""

    _response: AsyncIterable[RawMessageStreamEvent]
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        current_block: TextBlock | ToolUseBlock | None = None
        current_json: str = ''

        async for event in self._response:
            self._usage += _map_usage(event)

            if isinstance(event, RawContentBlockStartEvent):
                current_block = event.content_block
                if isinstance(current_block, TextBlock) and current_block.text:
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
                elif isinstance(current_block, ToolUseBlock):
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=current_block.id,
                        tool_name=current_block.name,
                        args=cast(dict[str, Any], current_block.input),
                        tool_call_id=current_block.id,
                    )
                    if maybe_event is not None:
                        yield maybe_event

            elif isinstance(event, RawContentBlockDeltaEvent):
                if isinstance(event.delta, TextDelta):
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
                elif (
                    current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
                ):
                    # Try to parse the JSON immediately, otherwise cache the value for later. This handles
                    # cases where the JSON is not currently valid but will be valid once we stream more tokens.
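                    # For example, the partial deltas '{"location": "Par' and
                    # 'is"}' only parse once concatenated into '{"location": "Paris"}'.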
                    try:
                        parsed_args = json_loads(current_json + event.delta.partial_json)
                        current_json = ''
                    except JSONDecodeError:
                        current_json += event.delta.partial_json
                        continue

                    # For tool calls, we need to handle partial JSON updates
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=current_block.id,
                        tool_name='',
                        args=parsed_args,
                        tool_call_id=current_block.id,
                    )
                    if maybe_event is not None:
                        yield maybe_event

            elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
                current_block = None

    def timestamp(self) -> datetime:
        return self._timestamp