Coverage for pydantic_ai_slim/pydantic_ai/models/cohere.py: 94.15%
125 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3from collections.abc import Iterable
4from dataclasses import dataclass, field
5from itertools import chain
6from typing import Literal, Union, cast
8from cohere import TextAssistantMessageContentItem
9from typing_extensions import assert_never
11from .. import ModelHTTPError, result
12from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
13from ..messages import (
14 ModelMessage,
15 ModelRequest,
16 ModelResponse,
17 ModelResponsePart,
18 RetryPromptPart,
19 SystemPromptPart,
20 TextPart,
21 ToolCallPart,
22 ToolReturnPart,
23 UserPromptPart,
24)
25from ..providers import Provider, infer_provider
26from ..settings import ModelSettings
27from ..tools import ToolDefinition
28from . import (
29 Model,
30 ModelRequestParameters,
31 check_allow_model_requests,
32)
34try:
35 from cohere import (
36 AssistantChatMessageV2,
37 AsyncClientV2,
38 ChatMessageV2,
39 ChatResponse,
40 SystemChatMessageV2,
41 ToolCallV2,
42 ToolCallV2Function,
43 ToolChatMessageV2,
44 ToolV2,
45 ToolV2Function,
46 UserChatMessageV2,
47 )
48 from cohere.core.api_error import ApiError
49 from cohere.v2.client import OMIT
50except ImportError as _import_error:
51 raise ImportError(
52 'Please install `cohere` to use the Cohere model, '
53 'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
54 ) from _import_error
LatestCohereModelNames = Literal[
    # Aya Expanse family
    'c4ai-aya-expanse-32b',
    'c4ai-aya-expanse-8b',
    # Command family
    'command',
    'command-light',
    'command-light-nightly',
    'command-nightly',
    # Command R family
    'command-r',
    'command-r-03-2024',
    'command-r-08-2024',
    'command-r-plus',
    'command-r-plus-04-2024',
    'command-r-plus-08-2024',
    'command-r7b-12-2024',
]
"""Names of the latest Cohere models that are listed explicitly."""

CohereModelName = Union[str, LatestCohereModelNames]
"""Any name accepted as a Cohere model name.

Cohere publishes many date-stamped models, so only the latest ones are enumerated in
`LatestCohereModelNames`; the type hint still accepts an arbitrary string.
See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
"""
class CohereModelSettings(ModelSettings):
    """Per-request settings for a Cohere model.

    ALL FIELDS MUST BE `cohere_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
    """

    # No Cohere-only settings are defined yet; this subclass reserves a place for them.
@dataclass(init=False)
class CohereModel(Model):
    """A model that uses the Cohere API.

    Internally, this uses the [Cohere Python client](
    https://github.com/cohere-ai/cohere-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: AsyncClientV2 = field(repr=False)

    _model_name: CohereModelName = field(repr=False)
    _system: str = field(default='cohere', repr=False)

    def __init__(
        self,
        model_name: CohereModelName,
        *,
        provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
    ):
        """Initialize a Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#command).
            provider: The provider to use for authentication and API access. Can be either the string
                'cohere' or an instance of `Provider[AsyncClientV2]`. A string is resolved to a
                provider with `infer_provider`; the provider's `client` is then used for all requests.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client

    @property
    def base_url(self) -> str:
        """The base URL reported by the underlying Cohere client, as a string."""
        # The client does not expose its wrapper publicly, hence the private attribute access.
        client_wrapper = self.client._client_wrapper  # type: ignore
        return str(client_wrapper.get_base_url())

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, result.Usage]:
        """Send `messages` to Cohere and return the mapped response together with its usage.

        Args:
            messages: The conversation so far.
            model_settings: Optional settings for this request; `None` is treated as no settings.
            model_request_parameters: The function and result tools to offer to the model.

        Returns:
            A tuple of the processed `ModelResponse` and the `result.Usage` derived from the
            Cohere response.
        """
        check_allow_model_requests()
        settings = cast(CohereModelSettings, model_settings or {})
        response = await self._chat(messages, settings, model_request_parameters)
        return self._process_response(response), _map_usage(response)

    @property
    def model_name(self) -> CohereModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    async def _chat(
        self,
        messages: list[ModelMessage],
        model_settings: CohereModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> ChatResponse:
        """Call the Cohere chat endpoint with the mapped messages, tools and settings.

        Settings that are absent from `model_settings` are passed as `OMIT`, as is an empty
        tool list.

        Raises:
            ModelHTTPError: If the client raises an `ApiError` whose status code is >= 400.
            ApiError: Re-raised unchanged when it carries no status code or one below 400.
        """
        tools = self._get_tools(model_request_parameters)
        # Every pydantic-ai message may expand to several Cohere messages; flatten them in order.
        cohere_messages = list(chain.from_iterable(self._map_message(m) for m in messages))
        try:
            return await self.client.chat(
                model=self._model_name,
                messages=cohere_messages,
                tools=tools or OMIT,
                max_tokens=model_settings.get('max_tokens', OMIT),
                temperature=model_settings.get('temperature', OMIT),
                p=model_settings.get('top_p', OMIT),
                seed=model_settings.get('seed', OMIT),
                presence_penalty=model_settings.get('presence_penalty', OMIT),
                frequency_penalty=model_settings.get('frequency_penalty', OMIT),
            )
        except ApiError as e:
            if (status_code := e.status_code) and status_code >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise

    def _process_response(self, response: ChatResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        parts: list[ModelResponsePart] = []
        if response.message.content is not None and len(response.message.content) > 0:
            # While Cohere's API returns a list, it only does that for future proofing
            # and currently only one item is being returned.
            choice = response.message.content[0]
            parts.append(TextPart(choice.text))
        for c in response.message.tool_calls or []:
            # NOTE: a tool call lacking a function, a name, or (non-empty) arguments is skipped
            # silently rather than reported.
            if c.function and c.function.name and c.function.arguments:
                parts.append(
                    ToolCallPart(
                        tool_name=c.function.name,
                        args=c.function.arguments,
                        tool_call_id=c.id or _generate_tool_call_id(),
                    )
                )
        return ModelResponse(parts=parts, model_name=self._model_name)

    def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
        """Map a `ModelMessage` to the `cohere.ChatMessageV2` message(s) that represent it.

        A `ModelRequest` yields one Cohere message per part. A `ModelResponse` yields a single
        assistant message whose text parts are joined with blank lines and whose tool calls are
        attached as a list.
        """
        if isinstance(message, ModelRequest):
            yield from self._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[ToolCallV2] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(self._map_tool_call(item))
                else:
                    assert_never(item)
            message_param = AssistantChatMessageV2(role='assistant')
            # Only set the fields that have content, leaving the others at their defaults.
            if texts:
                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
            if tool_calls:
                message_param.tool_calls = tool_calls
            yield message_param
        else:
            assert_never(message)

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
        """Return the Cohere tool definitions: function tools first, then any result tools."""
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
        """Convert a `ToolCallPart` into a Cohere function tool call with JSON-string arguments."""
        return ToolCallV2(
            id=_guard_tool_call_id(t=t),
            type='function',
            function=ToolCallV2Function(
                name=t.tool_name,
                arguments=t.args_as_json_str(),
            ),
        )

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
        """Convert a `ToolDefinition` into a Cohere function tool, passing its JSON schema through."""
        return ToolV2(
            type='function',
            function=ToolV2Function(
                name=f.name,
                description=f.description,
                parameters=f.parameters_json_schema,
            ),
        )

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
        """Yield one Cohere message for each part of a `ModelRequest`.

        Raises:
            RuntimeError: If a `UserPromptPart` carries non-string content.
        """
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                if isinstance(part.content, str):
                    yield UserChatMessageV2(role='user', content=part.content)
                else:
                    raise RuntimeError('Cohere does not yet support multi-modal inputs.')
            elif isinstance(part, ToolReturnPart):
                yield ToolChatMessageV2(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                # A retry that is not tied to a tool is sent as a user message; otherwise it is
                # sent as the tool's result.
                if part.tool_name is None:
                    yield UserChatMessageV2(role='user', content=part.model_response())
                else:
                    yield ToolChatMessageV2(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)
def _map_usage(response: ChatResponse) -> result.Usage:
    """Build a `result.Usage` from the usage section of a Cohere chat response.

    Returns an empty `result.Usage()` when the response has no usage. Otherwise the non-zero
    billed units are copied into `details`, and request/response token counts come from
    `usage.tokens` (each is `None` when missing or zero). `total_tokens` is their sum, with a
    missing count treated as 0.
    """
    usage = response.usage
    if usage is None:
        return result.Usage()

    details: dict[str, int] = {}
    billed = usage.billed_units
    if billed is not None:
        # Same keys, in the same order, as the fields on the billed-units object.
        for unit in ('input_tokens', 'output_tokens', 'search_units', 'classifications'):
            amount = getattr(billed, unit)
            if amount:
                details[unit] = int(amount)

    tokens = usage.tokens
    request_tokens = None
    response_tokens = None
    if tokens and tokens.input_tokens:
        request_tokens = int(tokens.input_tokens)
    if tokens and tokens.output_tokens:
        response_tokens = int(tokens.output_tokens)

    return result.Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=(request_tokens or 0) + (response_tokens or 0),
        details=details,
    )