# Coverage report residue (coverage.py v7.6.7, 2025-01-25): tests/models/test_anthropic.py — 97.33% of 140 statements covered.
1from __future__ import annotations as _annotations
3import json
4from dataclasses import dataclass, field
5from datetime import timezone
6from functools import cached_property
7from typing import Any, TypeVar, cast
9import pytest
10from inline_snapshot import snapshot
12from pydantic_ai import Agent, ModelRetry
13from pydantic_ai.messages import (
14 ModelRequest,
15 ModelResponse,
16 RetryPromptPart,
17 SystemPromptPart,
18 TextPart,
19 ToolCallPart,
20 ToolReturnPart,
21 UserPromptPart,
22)
23from pydantic_ai.result import Usage
24from pydantic_ai.settings import ModelSettings
26from ..conftest import IsNow, try_import
27from .mock_async_stream import MockAsyncStream
# Import the optional `anthropic` dependency inside `try_import` so this module
# still imports cleanly when the package is missing; `imports_successful()` then
# drives the module-level skip marker below.
with try_import() as imports_successful:
    from anthropic import NOT_GIVEN, AsyncAnthropic
    from anthropic.types import (
        ContentBlock,
        InputJSONDelta,
        Message as AnthropicMessage,
        MessageDeltaUsage,
        RawContentBlockDeltaEvent,
        RawContentBlockStartEvent,
        RawContentBlockStopEvent,
        RawMessageDeltaEvent,
        RawMessageStartEvent,
        RawMessageStopEvent,
        RawMessageStreamEvent,
        TextBlock,
        ToolUseBlock,
        Usage as AnthropicUsage,
    )
    from anthropic.types.raw_message_delta_event import Delta

    from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
# Skip every test in this module when `anthropic` isn't installed, and run all
# async tests under anyio.
pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='anthropic not installed'),
    pytest.mark.anyio,
]

# Type variable for generic AsyncStream
T = TypeVar('T')
def test_init():
    """The constructor's API key reaches the underlying client, and the model name round-trips."""
    model = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
    assert model.name() == 'anthropic:claude-3-5-haiku-latest'
    assert model.client.api_key == 'foobar'
@dataclass
class MockAnthropic:
    """In-memory stand-in for `AsyncAnthropic` that replays canned responses.

    Provide exactly one of `messages_` (non-streaming) or `stream` (streaming).
    Every `messages.create` call records its kwargs in `chat_completion_kwargs`
    so tests can assert on what was sent to the "API".
    """

    messages_: AnthropicMessage | list[AnthropicMessage] | None = None
    stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]] | None = None
    # Unannotated, so NOT a dataclass field: a class-level default that
    # `messages_create` shadows per instance to walk through canned responses.
    index = 0
    chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list)

    @cached_property
    def messages(self) -> Any:
        # Mimic the real SDK's `client.messages.create(...)` attribute chain.
        return type('Messages', (), {'create': self.messages_create})

    @classmethod
    def create_mock(cls, messages_: AnthropicMessage | list[AnthropicMessage]) -> AsyncAnthropic:
        """Build a mock client that replays canned non-streaming message(s)."""
        return cast(AsyncAnthropic, cls(messages_=messages_))

    @classmethod
    def create_stream_mock(
        cls, stream: list[RawMessageStreamEvent] | list[list[RawMessageStreamEvent]]
    ) -> AsyncAnthropic:
        """Build a mock client that replays canned stream event list(s)."""
        return cast(AsyncAnthropic, cls(stream=stream))

    async def messages_create(
        self, *_args: Any, stream: bool = False, **kwargs: Any
    ) -> AnthropicMessage | MockAsyncStream[RawMessageStreamEvent]:
        """Replay the next canned response, recording kwargs (minus `NOT_GIVEN` sentinels)."""
        self.chat_completion_kwargs.append({k: v for k, v in kwargs.items() if v is not NOT_GIVEN})

        if stream:
            assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
            # noinspection PyUnresolvedReferences
            if isinstance(self.stream[0], list):
                # A list of streams: replay one stream per call, in order.
                indexed_stream = cast(list[RawMessageStreamEvent], self.stream[self.index])
                response = MockAsyncStream(iter(indexed_stream))
            else:
                response = MockAsyncStream(iter(cast(list[RawMessageStreamEvent], self.stream)))
        else:
            assert self.messages_ is not None, '`messages` must be provided'
            if isinstance(self.messages_, list):
                response = self.messages_[self.index]
            else:
                response = self.messages_
        self.index += 1
        return response
def completion_message(content: list[ContentBlock], usage: AnthropicUsage) -> AnthropicMessage:
    """Build a minimal non-streaming Anthropic `Message` wrapping *content* and *usage*."""
    return AnthropicMessage(
        id='123',
        type='message',
        role='assistant',
        model='claude-3-5-haiku-latest',
        stop_reason='end_turn',
        content=content,
        usage=usage,
    )
async def test_sync_request_text_response(allow_model_requests: None):
    """Two runs against one mocked text completion: check data, usage, and full message history."""
    canned = completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
    client = MockAnthropic.create_mock(canned)
    agent = Agent(AnthropicModel('claude-3-5-haiku-latest', anthropic_client=client))

    first = await agent.run('hello')
    assert first.data == 'world'
    assert first.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15))

    # Rewind the mock so the second run replays the identical completion.
    client.index = 0  # type: ignore

    second = await agent.run('hello', message_history=first.new_messages())
    assert second.data == 'world'
    assert second.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15))
    assert second.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='world')],
                model_name='claude-3-5-haiku-latest',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[TextPart(content='world')],
                model_name='claude-3-5-haiku-latest',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
async def test_async_request_text_response(allow_model_requests: None):
    """A single run returns the mocked text and reports the mocked token usage."""
    client = MockAnthropic.create_mock(
        completion_message(
            [TextBlock(text='world', type='text')],
            usage=AnthropicUsage(input_tokens=3, output_tokens=5),
        )
    )
    agent = Agent(AnthropicModel('claude-3-5-haiku-latest', anthropic_client=client))

    result = await agent.run('hello')
    assert result.data == 'world'
    assert result.usage() == snapshot(Usage(requests=1, request_tokens=3, response_tokens=5, total_tokens=8))
async def test_request_structured_response(allow_model_requests: None):
    """A `final_result` tool-use block is decoded into the agent's structured result type."""
    tool_use = completion_message(
        [ToolUseBlock(id='123', input={'response': [1, 2, 3]}, name='final_result', type='tool_use')],
        usage=AnthropicUsage(input_tokens=3, output_tokens=5),
    )
    client = MockAnthropic.create_mock(tool_use)
    agent = Agent(AnthropicModel('claude-3-5-haiku-latest', anthropic_client=client), result_type=list[int])

    result = await agent.run('hello')
    assert result.data == [1, 2, 3]
    assert result.all_messages() == snapshot(
        [
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args={'response': [1, 2, 3]},
                        tool_call_id='123',
                    )
                ],
                model_name='claude-3-5-haiku-latest',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id='123',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
        ]
    )
async def test_request_tool_call(allow_model_requests: None):
    """Tool-call loop: a failed call triggers a retry prompt, a successful one a tool return,
    and the final text response ends the run. The full message history is snapshot-checked."""
    # Three canned turns: a bad tool call (San Francisco -> ModelRetry), a good
    # one (London), then the final plain-text answer.
    responses = [
        completion_message(
            [ToolUseBlock(id='1', input={'loc_name': 'San Francisco'}, name='get_location', type='tool_use')],
            usage=AnthropicUsage(input_tokens=2, output_tokens=1),
        ),
        completion_message(
            [ToolUseBlock(id='2', input={'loc_name': 'London'}, name='get_location', type='tool_use')],
            usage=AnthropicUsage(input_tokens=3, output_tokens=2),
        ),
        completion_message(
            [TextBlock(text='final response', type='text')],
            usage=AnthropicUsage(input_tokens=3, output_tokens=5),
        ),
    ]

    mock_client = MockAnthropic.create_mock(responses)
    m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
    agent = Agent(m, system_prompt='this is the system prompt')

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        # Only 'London' succeeds; any other location forces a model retry.
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        else:
            raise ModelRetry('Wrong location, please try again')

    result = await agent.run('hello')
    assert result.data == 'final response'
    assert result.all_messages() == snapshot(
        [
            ModelRequest(
                parts=[
                    SystemPromptPart(content='this is the system prompt'),
                    UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc)),
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='get_location',
                        args={'loc_name': 'San Francisco'},
                        tool_call_id='1',
                    )
                ],
                model_name='claude-3-5-haiku-latest',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    RetryPromptPart(
                        content='Wrong location, please try again',
                        tool_name='get_location',
                        tool_call_id='1',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='get_location',
                        args={'loc_name': 'London'},
                        tool_call_id='2',
                    )
                ],
                model_name='claude-3-5-haiku-latest',
                timestamp=IsNow(tz=timezone.utc),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='get_location',
                        content='{"lat": 51, "lng": 0}',
                        tool_call_id='2',
                        timestamp=IsNow(tz=timezone.utc),
                    )
                ]
            ),
            ModelResponse(
                parts=[TextPart(content='final response')],
                model_name='claude-3-5-haiku-latest',
                timestamp=IsNow(tz=timezone.utc),
            ),
        ]
    )
def get_mock_chat_completion_kwargs(async_anthropic: AsyncAnthropic) -> list[dict[str, Any]]:
    """Return the kwargs recorded for each `messages.create` call on the mock client.

    Raises:
        RuntimeError: if passed a real `AsyncAnthropic` instead of a `MockAnthropic`.
    """
    if isinstance(async_anthropic, MockAnthropic):
        return async_anthropic.chat_completion_kwargs
    else:  # pragma: no cover
        # Bug fix: the message previously said 'MockOpenAI' (copy-paste from the OpenAI tests).
        raise RuntimeError('Not a MockAnthropic instance')
@pytest.mark.parametrize('parallel_tool_calls', [True, False])
async def test_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls: bool) -> None:
    """The `parallel_tool_calls` setting maps to Anthropic's inverse `disable_parallel_tool_use` flag."""
    responses = [
        completion_message(
            [ToolUseBlock(id='1', input={'loc_name': 'San Francisco'}, name='get_location', type='tool_use')],
            usage=AnthropicUsage(input_tokens=2, output_tokens=1),
        ),
        completion_message(
            [TextBlock(text='final response', type='text')],
            usage=AnthropicUsage(input_tokens=3, output_tokens=5),
        ),
    ]

    mock_client = MockAnthropic.create_mock(responses)
    m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
    agent = Agent(m, model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))

    @agent.tool_plain
    async def get_location(loc_name: str) -> str:
        # The canned response only ever asks for San Francisco, so this always retries.
        if loc_name == 'London':
            return json.dumps({'lat': 51, 'lng': 0})
        else:
            raise ModelRetry('Wrong location, please try again')

    await agent.run('hello')
    # The Anthropic flag is the logical inverse of the pydantic-ai setting.
    assert get_mock_chat_completion_kwargs(mock_client)[0]['tool_choice']['disable_parallel_tool_use'] == (
        not parallel_tool_calls
    )
async def test_anthropic_specific_metadata(allow_model_requests: None) -> None:
    """`anthropic_metadata` from `AnthropicModelSettings` is forwarded to `messages.create`."""
    client = MockAnthropic.create_mock(
        completion_message([TextBlock(text='world', type='text')], AnthropicUsage(input_tokens=5, output_tokens=10))
    )
    agent = Agent(AnthropicModel('claude-3-5-haiku-latest', anthropic_client=client))

    result = await agent.run('hello', model_settings=AnthropicModelSettings(anthropic_metadata={'user_id': '123'}))
    assert result.data == 'world'
    assert get_mock_chat_completion_kwargs(client)[0]['metadata']['user_id'] == '123'
async def test_stream_structured(allow_model_requests: None):
    """Test streaming structured responses with Anthropic's API.

    This test simulates how Anthropic streams tool calls:
    1. Message start
    2. Tool block start with initial data
    3. Tool block delta with additional data
    4. Tool block stop
    5. Update usage
    6. Message stop
    """
    # First stream: a tool call assembled from a start block plus two JSON deltas.
    stream: list[RawMessageStreamEvent] = [
        RawMessageStartEvent(
            type='message_start',
            message=AnthropicMessage(
                id='msg_123',
                model='claude-3-5-haiku-latest',
                role='assistant',
                type='message',
                content=[],
                stop_reason=None,
                usage=AnthropicUsage(input_tokens=20, output_tokens=0),
            ),
        ),
        # Start tool block with initial data
        RawContentBlockStartEvent(
            type='content_block_start',
            index=0,
            content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={'first': 'One'}),
        ),
        # Add more data through an incomplete JSON delta
        RawContentBlockDeltaEvent(
            type='content_block_delta',
            index=0,
            delta=InputJSONDelta(type='input_json_delta', partial_json='{"second":'),
        ),
        RawContentBlockDeltaEvent(
            type='content_block_delta',
            index=0,
            delta=InputJSONDelta(type='input_json_delta', partial_json='"Two"}'),
        ),
        # Mark tool block as complete
        RawContentBlockStopEvent(type='content_block_stop', index=0),
        # Update the top-level message with usage
        RawMessageDeltaEvent(
            type='message_delta',
            delta=Delta(
                stop_reason='end_turn',
            ),
            usage=MessageDeltaUsage(
                output_tokens=5,
            ),
        ),
        # Mark message as complete
        RawMessageStopEvent(type='message_stop'),
    ]

    # Second stream (after the tool return): a plain text block ending the run.
    done_stream: list[RawMessageStreamEvent] = [
        RawMessageStartEvent(
            type='message_start',
            message=AnthropicMessage(
                id='msg_123',
                model='claude-3-5-haiku-latest',
                role='assistant',
                type='message',
                content=[],
                stop_reason=None,
                usage=AnthropicUsage(input_tokens=0, output_tokens=0),
            ),
        ),
        # Text block with final data
        RawContentBlockStartEvent(
            type='content_block_start',
            index=0,
            content_block=TextBlock(type='text', text='FINAL_PAYLOAD'),
        ),
        RawContentBlockStopEvent(type='content_block_stop', index=0),
        RawMessageStopEvent(type='message_stop'),
    ]

    mock_client = MockAnthropic.create_stream_mock([stream, done_stream])
    m = AnthropicModel('claude-3-5-haiku-latest', anthropic_client=mock_client)
    agent = Agent(m)

    tool_called = False

    @agent.tool_plain
    async def my_tool(first: str, second: str) -> int:
        nonlocal tool_called
        tool_called = True
        return len(first) + len(second)

    async with agent.run_stream('') as result:
        assert not result.is_complete
        chunks = [c async for c in result.stream(debounce_by=None)]

        # The tool output doesn't echo any content to the stream, so we only get the final payload once when
        # the block starts and once when it ends.
        assert chunks == snapshot(
            [
                'FINAL_PAYLOAD',
                'FINAL_PAYLOAD',
            ]
        )
        assert result.is_complete
        assert result.usage() == snapshot(Usage(requests=2, request_tokens=20, response_tokens=5, total_tokens=25))
    assert tool_called