Coverage for pydantic_ai_slim/pydantic_ai/models/cohere.py: 93.04% (116 statements)
from __future__ import annotations as _annotations

from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import chain
from typing import Literal, Union, cast

from cohere import TextAssistantMessageContentItem
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import result
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
    AgentModel,
    Model,
    check_allow_model_requests,
)

try:
    from cohere import (
        AssistantChatMessageV2,
        AsyncClientV2,
        ChatMessageV2,
        ChatResponse,
        SystemChatMessageV2,
        ToolCallV2,
        ToolCallV2Function,
        ToolChatMessageV2,
        ToolV2,
        ToolV2Function,
        UserChatMessageV2,
    )
    from cohere.v2.client import OMIT
except ImportError as _import_error:
    raise ImportError(
        'Please install `cohere` to use the Cohere model, '
        "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
    ) from _import_error

NamedCohereModels = Literal[
    'c4ai-aya-expanse-32b',
    'c4ai-aya-expanse-8b',
    'command',
    'command-light',
    'command-light-nightly',
    'command-nightly',
    'command-r',
    'command-r-03-2024',
    'command-r-08-2024',
    'command-r-plus',
    'command-r-plus-04-2024',
    'command-r-plus-08-2024',
    'command-r7b-12-2024',
]
"""Latest / most popular named Cohere models."""

CohereModelName = Union[NamedCohereModels, str]
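# Note: `CohereModelName` also accepts arbitrary strings, so model IDs that are not in
# the `NamedCohereModels` literal above can still be passed through unchanged, for
# example a hypothetical fine-tuned model ID:
#
#     model_name: CohereModelName = 'my-finetuned-command-r'  # illustrative, not a real model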


class CohereModelSettings(ModelSettings):
    """Settings used for a Cohere model request."""

    # This class is a placeholder for any future cohere-specific settings


@dataclass(init=False)
class CohereModel(Model):
    """A model that uses the Cohere API.

    Internally, this uses the [Cohere Python client](
    https://github.com/cohere-ai/cohere-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: CohereModelName
    client: AsyncClientV2 = field(repr=False)

    def __init__(
        self,
        model_name: CohereModelName,
        *,
        api_key: str | None = None,
        cohere_client: AsyncClientV2 | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize a Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#command).
            api_key: The API key to use for authentication; if not provided, the
                `CO_API_KEY` environment variable will be used if available.
            cohere_client: An existing Cohere async client to use. If provided,
                `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name: CohereModelName = model_name
        if cohere_client is not None:
            assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
            self.client = cohere_client
        else:
            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore
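    # Minimal usage sketch (the model name and key below are illustrative, not values
    # taken from this module): construct the model either from an API key or from a
    # pre-built `AsyncClientV2`, but not both.
    #
    #     model = CohereModel('command-r-plus', api_key='<your CO_API_KEY>')
    #     # or, reusing an existing client:
    #     model = CohereModel('command-r-plus', cohere_client=AsyncClientV2(api_key='<your CO_API_KEY>'))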

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        check_allow_model_requests()
        tools = [self._map_tool_definition(r) for r in function_tools]
        if result_tools:
            tools += [self._map_tool_definition(r) for r in result_tools]
        return CohereAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            tools,
        )

    def name(self) -> str:
        return f'cohere:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
        return ToolV2(
            type='function',
            function=ToolV2Function(
                name=f.name,
                description=f.description,
                parameters=f.parameters_json_schema,
            ),
        )
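    # For example (a sketch with a hypothetical `get_weather` tool; the name, description
    # and schema are assumptions for illustration), `_map_tool_definition` would produce:
    #
    #     ToolV2(
    #         type='function',
    #         function=ToolV2Function(
    #             name='get_weather',
    #             description='Get the current weather for a city.',
    #             parameters={'type': 'object', 'properties': {'city': {'type': 'string'}}},
    #         ),
    #     )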


@dataclass
class CohereAgentModel(AgentModel):
    """Implementation of `AgentModel` for Cohere models."""

    client: AsyncClientV2
    model_name: CohereModelName
    allow_text_result: bool
    tools: list[ToolV2]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, result.Usage]:
        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    async def _chat(
        self,
        messages: list[ModelMessage],
        model_settings: CohereModelSettings,
    ) -> ChatResponse:
        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
        return await self.client.chat(
            model=self.model_name,
            messages=cohere_messages,
            tools=self.tools or OMIT,
            max_tokens=model_settings.get('max_tokens', OMIT),
            temperature=model_settings.get('temperature', OMIT),
            p=model_settings.get('top_p', OMIT),
            seed=model_settings.get('seed', OMIT),
            presence_penalty=model_settings.get('presence_penalty', OMIT),
            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
        )
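    # Any setting the caller does not supply falls back to Cohere's `OMIT` sentinel, so
    # the API applies its own defaults. A sketch of a request that overrides two
    # settings (the values are illustrative):
    #
    #     response, usage = await agent_model.request(messages, {'temperature': 0.2, 'max_tokens': 256})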

    def _process_response(self, response: ChatResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        parts: list[ModelResponsePart] = []
        if response.message.content is not None and len(response.message.content) > 0:
            # While Cohere's API returns a list, it only does that for future proofing
            # and currently only one item is being returned.
            choice = response.message.content[0]
            parts.append(TextPart(choice.text))
        for c in response.message.tool_calls or []:
            if c.function and c.function.name and c.function.arguments:
                parts.append(
                    ToolCallPart(
                        tool_name=c.function.name,
                        args=c.function.arguments,
                        tool_call_id=c.id,
                    )
                )
        return ModelResponse(parts=parts, model_name=self.model_name)
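    # For instance, a response whose message carries one text item and one tool call is
    # mapped to ModelResponse(parts=[TextPart(...), ToolCallPart(...)]); a response with
    # neither produces an empty parts list.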

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[ToolCallV2] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = AssistantChatMessageV2(role='assistant')
            if texts:
                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
            if tool_calls:
                message_param.tool_calls = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield UserChatMessageV2(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield ToolChatMessageV2(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    yield UserChatMessageV2(role='user', content=part.model_response())
                else:
                    yield ToolChatMessageV2(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)


def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
    return ToolCallV2(
        id=_guard_tool_call_id(t=t, model_source='Cohere'),
        type='function',
        function=ToolCallV2Function(
            name=t.tool_name,
            arguments=t.args_as_json_str(),
        ),
    )


def _map_usage(response: ChatResponse) -> result.Usage:
    usage = response.usage
    if usage is None:
        return result.Usage()
    else:
        details: dict[str, int] = {}
        if usage.billed_units is not None:
            if usage.billed_units.input_tokens:
                details['input_tokens'] = int(usage.billed_units.input_tokens)
            if usage.billed_units.output_tokens:
                details['output_tokens'] = int(usage.billed_units.output_tokens)
            if usage.billed_units.search_units:
                details['search_units'] = int(usage.billed_units.search_units)
            if usage.billed_units.classifications:
                details['classifications'] = int(usage.billed_units.classifications)

        request_tokens = int(usage.tokens.input_tokens) if usage.tokens and usage.tokens.input_tokens else None
        response_tokens = int(usage.tokens.output_tokens) if usage.tokens and usage.tokens.output_tokens else None
        return result.Usage(
            request_tokens=request_tokens,
            response_tokens=response_tokens,
            total_tokens=(request_tokens or 0) + (response_tokens or 0),
            details=details,
        )
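# Worked example (illustrative numbers): a response whose `usage.tokens` reports 20
# input tokens and 5 output tokens maps to result.Usage(request_tokens=20,
# response_tokens=5, total_tokens=25), with any billed-unit counts kept in `details`.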