Coverage for pydantic_ai_slim/pydantic_ai/models/anthropic.py: 94.35%

164 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3from collections.abc import AsyncIterable, AsyncIterator 

4from contextlib import asynccontextmanager 

5from dataclasses import dataclass, field 

6from datetime import datetime, timezone 

7from json import JSONDecodeError, loads as json_loads 

8from typing import Any, Literal, Union, cast, overload 

9 

10from httpx import AsyncClient as AsyncHTTPClient 

11from typing_extensions import assert_never 

12 

13from .. import UnexpectedModelBehavior, _utils, usage 

14from .._utils import guard_tool_call_id as _guard_tool_call_id 

15from ..messages import ( 

16 ModelMessage, 

17 ModelRequest, 

18 ModelResponse, 

19 ModelResponsePart, 

20 ModelResponseStreamEvent, 

21 RetryPromptPart, 

22 SystemPromptPart, 

23 TextPart, 

24 ToolCallPart, 

25 ToolReturnPart, 

26 UserPromptPart, 

27) 

28from ..settings import ModelSettings 

29from ..tools import ToolDefinition 

30from . import ( 

31 AgentModel, 

32 Model, 

33 StreamedResponse, 

34 cached_async_http_client, 

35 check_allow_model_requests, 

36) 

37 

38try: 

39 from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream 

40 from anthropic.types import ( 

41 Message as AnthropicMessage, 

42 MessageParam, 

43 MetadataParam, 

44 RawContentBlockDeltaEvent, 

45 RawContentBlockStartEvent, 

46 RawContentBlockStopEvent, 

47 RawMessageDeltaEvent, 

48 RawMessageStartEvent, 

49 RawMessageStopEvent, 

50 RawMessageStreamEvent, 

51 TextBlock, 

52 TextBlockParam, 

53 TextDelta, 

54 ToolChoiceParam, 

55 ToolParam, 

56 ToolResultBlockParam, 

57 ToolUseBlock, 

58 ToolUseBlockParam, 

59 ) 

60except ImportError as _import_error: 

61 raise ImportError( 

62 'Please install `anthropic` to use the Anthropic model, ' 

63 "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`" 

64 ) from _import_error 

65 

LatestAnthropicModelNames = Literal[
    'claude-3-5-haiku-latest',
    'claude-3-5-sonnet-latest',
    'claude-3-opus-latest',
]
"""Latest named Anthropic models."""

AnthropicModelName = Union[str, LatestAnthropicModelNames]
"""Possible Anthropic model names.

Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
"""

80 

81 

class AnthropicModelSettings(ModelSettings, total=False):
    """Settings used for an Anthropic model request.

    `total=False` keeps every key optional: callers construct these settings incrementally
    (e.g. `cast(AnthropicModelSettings, model_settings or {})`) and the request code reads
    `anthropic_metadata` with `.get(..., NOT_GIVEN)`, so the key must not be required.
    """

    anthropic_metadata: MetadataParam
    """An object describing metadata about the request.

    Contains `user_id`, an external identifier for the user who is associated with the request."""

89 

90 

@dataclass(init=False)
class AnthropicModel(Model):
    """A model that uses the Anthropic API.

    Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    # Name sent to the Anthropic API as the `model` parameter.
    model_name: AnthropicModelName
    # Underlying SDK client; excluded from repr so credentials/config aren't printed.
    client: AsyncAnthropic = field(repr=False)

    def __init__(
        self,
        model_name: AnthropicModelName,
        *,
        api_key: str | None = None,
        anthropic_client: AsyncAnthropic | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize an Anthropic model.

        Args:
            model_name: The name of the Anthropic model to use. List of model names available
                [here](https://docs.anthropic.com/en/docs/about-claude/models).
            api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
                will be used if available.
            anthropic_client: An existing
                [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
                client to use, if provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name = model_name
        if anthropic_client is not None:
            # A pre-built client is mutually exclusive with the other connection options.
            assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
            self.client = anthropic_client
        elif http_client is not None:
            self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
        else:
            # Fall back to the shared cached HTTP client so connections are reused across models.
            self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Build an `AgentModel` for a run; tool definitions are mapped to Anthropic's format here."""
        check_allow_model_requests()
        # Function tools and result tools are sent to the API as a single flat tool list.
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return AnthropicAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        """Return the model's display name in the form `anthropic:<model_name>`."""
        return f'anthropic:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
        """Convert a pydantic-ai `ToolDefinition` into an Anthropic `ToolParam` dict."""
        return {
            'name': f.name,
            'description': f.description,
            'input_schema': f.parameters_json_schema,
        }

165 

166 

@dataclass
class AnthropicAgentModel(AgentModel):
    """Implementation of `AgentModel` for Anthropic models."""

    client: AsyncAnthropic
    model_name: AnthropicModelName
    allow_text_result: bool
    tools: list[ToolParam]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, usage.Usage]:
        """Make a non-streaming request and return the model response together with its token usage."""
        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request, yielding a `StreamedResponse` while the HTTP stream is open."""
        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
        async with response:
            yield await self._process_streamed_response(response)

    @overload
    async def _messages_create(
        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
    ) -> AsyncStream[RawMessageStreamEvent]:
        pass

    @overload
    async def _messages_create(
        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
    ) -> AnthropicMessage:
        pass

    async def _messages_create(
        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
    ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
        """Call the Anthropic messages API.

        Kept as a standalone method (rather than inlined in `request`/`request_stream`)
        to make it easier to override.
        """
        tool_choice: ToolChoiceParam | None
        if not self.tools:
            tool_choice = None
        elif not self.allow_text_result:
            # A plain-text answer isn't acceptable here, so force the model to call one of the tools.
            tool_choice = {'type': 'any'}
        else:
            tool_choice = {'type': 'auto'}

        if tool_choice is not None and (parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
            tool_choice['disable_parallel_tool_use'] = not parallel_tool_calls

        system_prompt, anthropic_messages = self._map_message(messages)

        return await self.client.messages.create(
            max_tokens=model_settings.get('max_tokens', 1024),
            system=system_prompt or NOT_GIVEN,
            messages=anthropic_messages,
            model=self.model_name,
            tools=self.tools or NOT_GIVEN,
            tool_choice=tool_choice or NOT_GIVEN,
            stream=stream,
            temperature=model_settings.get('temperature', NOT_GIVEN),
            top_p=model_settings.get('top_p', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
        )

    def _process_response(self, response: AnthropicMessage) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        items: list[ModelResponsePart] = []
        for item in response.content:
            if isinstance(item, TextBlock):
                items.append(TextPart(content=item.text))
            else:
                # The API currently returns only text or tool-use content blocks.
                assert isinstance(item, ToolUseBlock), 'unexpected item type'
                items.append(
                    ToolCallPart(
                        tool_name=item.name,
                        args=cast(dict[str, Any], item.input),
                        tool_call_id=item.id,
                    )
                )

        return ModelResponse(items, model_name=self.model_name)

    async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
        """Wrap a raw event stream in an `AnthropicStreamedResponse`, failing fast if the stream is empty."""
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
        timestamp = datetime.now(tz=timezone.utc)
        return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)

    @staticmethod
    def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
        """Map `pydantic_ai` messages to a system prompt plus `anthropic.types.MessageParam`s.

        System prompt parts are collected separately (Anthropic takes the system prompt as a
        top-level request parameter, not as a message) and joined with blank lines so multiple
        system prompts don't run together into one word.
        """
        system_prompt_parts: list[str] = []
        anthropic_messages: list[MessageParam] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        system_prompt_parts.append(part.content)
                    elif isinstance(part, UserPromptPart):
                        anthropic_messages.append(MessageParam(role='user', content=part.content))
                    elif isinstance(part, ToolReturnPart):
                        anthropic_messages.append(
                            MessageParam(
                                role='user',
                                content=[
                                    ToolResultBlockParam(
                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
                                        type='tool_result',
                                        content=part.model_response_str(),
                                        is_error=False,
                                    )
                                ],
                            )
                        )
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            # A retry that isn't tied to a tool call is sent as a plain user message.
                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
                        else:
                            # A tool-specific retry is reported back as an errored tool result.
                            anthropic_messages.append(
                                MessageParam(
                                    role='user',
                                    content=[
                                        ToolResultBlockParam(
                                            tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
                                            type='tool_result',
                                            content=part.model_response(),
                                            is_error=True,
                                        ),
                                    ],
                                )
                            )
            elif isinstance(m, ModelResponse):
                content: list[TextBlockParam | ToolUseBlockParam] = []
                for item in m.parts:
                    if isinstance(item, TextPart):
                        content.append(TextBlockParam(text=item.content, type='text'))
                    else:
                        assert isinstance(item, ToolCallPart)
                        content.append(_map_tool_call(item))
                anthropic_messages.append(MessageParam(role='assistant', content=content))
            else:
                assert_never(m)
        return '\n\n'.join(system_prompt_parts), anthropic_messages

318 

319 

def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
    """Convert a pydantic-ai tool call part into an Anthropic `tool_use` block param."""
    return {
        'id': _guard_tool_call_id(t=t, model_source='Anthropic'),
        'type': 'tool_use',
        'name': t.tool_name,
        'input': t.args_as_dict(),
    }

327 

328 

def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
    """Extract token usage from a complete Anthropic message or a raw streaming event."""
    if isinstance(message, AnthropicMessage):
        response_usage = message.usage
    elif isinstance(message, RawMessageStartEvent):
        response_usage = message.message.usage
    elif isinstance(message, RawMessageDeltaEvent):
        response_usage = message.usage
    else:
        # The remaining event types carry no usage information:
        # RawMessageStopEvent, RawContentBlockStartEvent,
        # RawContentBlockDeltaEvent, RawContentBlockStopEvent.
        response_usage = None

    if response_usage is None:
        return usage.Usage()

    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr.
    request_tokens = getattr(response_usage, 'input_tokens', None)
    response_tokens = response_usage.output_tokens

    return usage.Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=(request_tokens or 0) + response_tokens,
    )

356 

357 

@dataclass
class AnthropicStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Anthropic models."""

    # Source of raw Anthropic stream events (peekable wrapper around the SDK stream).
    _response: AsyncIterable[RawMessageStreamEvent]
    # Time the response was created; Anthropic doesn't supply one, so the caller records "now".
    _timestamp: datetime

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Translate raw Anthropic stream events into `ModelResponseStreamEvent`s.

        Tracks the content block currently being streamed (`current_block`) and, for tool-use
        blocks, accumulates partial JSON fragments in `current_json` until they parse.
        """
        current_block: TextBlock | ToolUseBlock | None = None
        current_json: str = ''

        async for event in self._response:
            # Usage can arrive on message-start and message-delta events; accumulate it as we go.
            self._usage += _map_usage(event)

            if isinstance(event, RawContentBlockStartEvent):
                current_block = event.content_block
                if isinstance(current_block, TextBlock) and current_block.text:
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
                elif isinstance(current_block, ToolUseBlock):
                    # The start event may already carry the tool name and (possibly empty) input.
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=current_block.id,
                        tool_name=current_block.name,
                        args=cast(dict[str, Any], current_block.input),
                        tool_call_id=current_block.id,
                    )
                    if maybe_event is not None:
                        yield maybe_event

            elif isinstance(event, RawContentBlockDeltaEvent):
                if isinstance(event.delta, TextDelta):
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
                elif (
                    current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
                ):
                    # Try to parse the JSON immediately, otherwise cache the value for later. This handles
                    # cases where the JSON is not currently valid but will be valid once we stream more tokens.
                    try:
                        parsed_args = json_loads(current_json + event.delta.partial_json)
                        current_json = ''
                    except JSONDecodeError:
                        current_json += event.delta.partial_json
                        continue

                    # For tool calls, we need to handle partial JSON updates
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=current_block.id,
                        tool_name='',
                        args=parsed_args,
                        tool_call_id=current_block.id,
                    )
                    if maybe_event is not None:
                        yield maybe_event

            elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
                # Block (or message) finished: stop attributing deltas to it.
                current_block = None

    def timestamp(self) -> datetime:
        """Return the timestamp recorded when the streamed response was created."""
        return self._timestamp