Coverage for pydantic_ai_slim/pydantic_ai/models/cohere.py: 92.86%
112 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3from collections.abc import Iterable
4from dataclasses import dataclass, field
5from itertools import chain
6from typing import Literal, TypeAlias, Union, cast
8from cohere import TextAssistantMessageContentItem
9from typing_extensions import assert_never
11from .. import result
12from .._utils import guard_tool_call_id as _guard_tool_call_id
13from ..messages import (
14 ModelMessage,
15 ModelRequest,
16 ModelResponse,
17 ModelResponsePart,
18 RetryPromptPart,
19 SystemPromptPart,
20 TextPart,
21 ToolCallPart,
22 ToolReturnPart,
23 UserPromptPart,
24)
25from ..settings import ModelSettings
26from ..tools import ToolDefinition
27from . import (
28 AgentModel,
29 Model,
30 check_allow_model_requests,
31)
33try:
34 from cohere import (
35 AssistantChatMessageV2,
36 AsyncClientV2,
37 ChatMessageV2,
38 ChatResponse,
39 SystemChatMessageV2,
40 ToolCallV2,
41 ToolCallV2Function,
42 ToolChatMessageV2,
43 ToolV2,
44 ToolV2Function,
45 UserChatMessageV2,
46 )
47 from cohere.v2.client import OMIT
48except ImportError as _import_error:
49 raise ImportError(
50 'Please install `cohere` to use the Cohere model, '
51 "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
52 ) from _import_error
54CohereModelName: TypeAlias = Union[
55 str,
56 Literal[
57 'c4ai-aya-expanse-32b',
58 'c4ai-aya-expanse-8b',
59 'command',
60 'command-light',
61 'command-light-nightly',
62 'command-nightly',
63 'command-r',
64 'command-r-03-2024',
65 'command-r-08-2024',
66 'command-r-plus',
67 'command-r-plus-04-2024',
68 'command-r-plus-08-2024',
69 'command-r7b-12-2024',
70 ],
71]
class CohereModelSettings(ModelSettings):
    """Per-request settings accepted by the Cohere model.

    No Cohere-specific keys exist yet; this subclass is kept as the place to add them
    later. Until then it carries exactly the keys inherited from `ModelSettings`.
    """
@dataclass(init=False)
class CohereModel(Model):
    """A `Model` implementation that talks to the Cohere chat API.

    Requests go through the async v2 client (`AsyncClientV2`) of the
    [Cohere Python client](https://github.com/cohere-ai/cohere-python).

    Only `__init__` is meant to be called directly; every other method is either
    private or an implementation of the base `Model` interface.
    """

    model_name: CohereModelName
    client: AsyncClientV2 = field(repr=False)

    def __init__(
        self,
        model_name: CohereModelName,
        *,
        api_key: str | None = None,
        cohere_client: AsyncClientV2 | None = None,
    ):
        """Initialize a Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#command).
            api_key: The API key to use for authentication, if not provided, the
                `COHERE_API_KEY` environment variable will be used if available.
            cohere_client: An existing Cohere async client to use. If provided,
                `api_key` must be `None`.
        """
        self.model_name: CohereModelName = model_name
        if cohere_client is None:
            # No client supplied: build one from the (possibly None) API key.
            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
        else:
            # A ready-made client already carries its own credentials, so a separate
            # key would be ambiguous.
            assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
            self.client = cohere_client

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Build a `CohereAgentModel` exposing both function tools and result tools.

        Function tools come first, followed by result tools. All of them are sent to
        Cohere as ordinary function tools.
        """
        check_allow_model_requests()
        cohere_tools = [self._map_tool_definition(tool_def) for tool_def in function_tools]
        cohere_tools.extend(self._map_tool_definition(tool_def) for tool_def in result_tools or ())
        return CohereAgentModel(
            client=self.client,
            model_name=self.model_name,
            allow_text_result=allow_text_result,
            tools=cohere_tools,
        )

    def name(self) -> str:
        """Return the model identifier, prefixed with `cohere:`."""
        return f'cohere:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
        """Translate a pydantic-ai `ToolDefinition` into a Cohere `ToolV2` of type `function`."""
        tool_function = ToolV2Function(
            name=f.name,
            description=f.description,
            parameters=f.parameters_json_schema,
        )
        return ToolV2(type='function', function=tool_function)
@dataclass
class CohereAgentModel(AgentModel):
    """Implementation of `AgentModel` for Cohere models."""

    client: AsyncClientV2
    model_name: CohereModelName
    allow_text_result: bool
    tools: list[ToolV2]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, result.Usage]:
        """Send `messages` to Cohere and return the parsed response plus its usage.

        A `None` `model_settings` is treated as an empty settings dict.
        """
        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    async def _chat(
        self,
        messages: list[ModelMessage],
        model_settings: CohereModelSettings,
    ) -> ChatResponse:
        """Call the Cohere v2 `chat` endpoint with the mapped conversation and settings.

        Settings missing from `model_settings`, and an empty tool list, are passed as
        `OMIT`. This presumably tells the Cohere client to leave the parameter out of the
        request entirely (TODO confirm against the cohere SDK).
        """
        # Each pydantic-ai message can expand into several Cohere messages, so flatten.
        cohere_messages = list(chain.from_iterable(self._map_message(m) for m in messages))
        return await self.client.chat(
            model=self.model_name,
            messages=cohere_messages,
            tools=self.tools or OMIT,
            max_tokens=model_settings.get('max_tokens', OMIT),
            temperature=model_settings.get('temperature', OMIT),
            # Cohere names nucleus sampling `p` rather than `top_p`.
            p=model_settings.get('top_p', OMIT),
            seed=model_settings.get('seed', OMIT),
            presence_penalty=model_settings.get('presence_penalty', OMIT),
            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
        )

    def _process_response(self, response: ChatResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return.

        Every content item becomes a `TextPart`, in order. Every tool call that has
        a function name and arguments becomes a `ToolCallPart`; tool calls missing
        either are skipped.
        """
        parts: list[ModelResponsePart] = []
        # Cohere returns content as a list "for future proofing" and normally sends a
        # single item. Keep all items rather than only the first, so no text is silently
        # dropped if more than one is returned.
        for item in response.message.content or []:
            parts.append(TextPart(item.text))
        for c in response.message.tool_calls or []:
            if c.function and c.function.name and c.function.arguments:
                parts.append(
                    ToolCallPart(
                        tool_name=c.function.name,
                        args=c.function.arguments,
                        tool_call_id=c.id,
                    )
                )
        return ModelResponse(parts=parts, model_name=self.model_name)

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
        """Map one pydantic-ai message to the Cohere v2 chat messages it corresponds to.

        A `ModelRequest` may yield several messages (one per part). A `ModelResponse`
        yields one assistant message: its text parts are joined with blank lines into a
        single content item, and its tool calls are attached as `tool_calls`.
        """
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[ToolCallV2] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            # NOTE(review): a response with no parts produces an assistant message with
            # neither content nor tool_calls. Confirm Cohere accepts that.
            message_param = AssistantChatMessageV2(role='assistant')
            if texts:
                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
            if tool_calls:
                message_param.tool_calls = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
        """Yield one Cohere message per part of a pydantic-ai `ModelRequest`.

        - system prompts map to `system` messages;
        - user prompts map to `user` messages;
        - tool returns map to `tool` messages;
        - retry prompts map to `tool` messages when they name a tool, and to `user`
          messages otherwise.
        """
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield UserChatMessageV2(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield ToolChatMessageV2(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    # Not tied to a tool call, so it is fed back as plain user text.
                    yield UserChatMessageV2(role='user', content=part.model_response())
                else:
                    yield ToolChatMessageV2(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)
def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
    """Convert a pydantic-ai `ToolCallPart` into a Cohere `ToolCallV2` of type `function`.

    The tool call id is validated first, via `_guard_tool_call_id`. The arguments are
    then serialized to a JSON string.
    """
    call_id = _guard_tool_call_id(t=t, model_source='Cohere')
    serialized_args = t.args_as_json_str()
    call_function = ToolCallV2Function(name=t.tool_name, arguments=serialized_args)
    return ToolCallV2(id=call_id, type='function', function=call_function)
def _map_usage(response: ChatResponse) -> result.Usage:
    """Build a `result.Usage` from the `usage` field of a Cohere chat response.

    If the response has no usage, an empty `Usage()` is returned. Otherwise:

    - request and response token counts come from `usage.tokens`; a missing or zero
      count becomes `None`;
    - `total_tokens` is their sum, with `None` counted as 0;
    - every non-zero billed-unit counter is copied into `details` as an int.
    """
    usage = response.usage
    if usage is None:
        return result.Usage()

    details: dict[str, int] = {}
    billed = usage.billed_units
    if billed is not None:
        # Record only counters that are present and non-zero, in this fixed order.
        for counter in ('input_tokens', 'output_tokens', 'search_units', 'classifications'):
            amount = getattr(billed, counter)
            if amount:
                details[counter] = int(amount)

    tokens = usage.tokens
    request_tokens = int(tokens.input_tokens) if tokens and tokens.input_tokens else None
    response_tokens = int(tokens.output_tokens) if tokens and tokens.output_tokens else None
    return result.Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=(request_tokens or 0) + (response_tokens or 0),
        details=details,
    )