Coverage for pydantic_ai_slim/pydantic_ai/models/cohere.py: 94.15%

125 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3from collections.abc import Iterable 

4from dataclasses import dataclass, field 

5from itertools import chain 

6from typing import Literal, Union, cast 

7 

8from cohere import TextAssistantMessageContentItem 

9from typing_extensions import assert_never 

10 

11from .. import ModelHTTPError, result 

12from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id 

13from ..messages import ( 

14 ModelMessage, 

15 ModelRequest, 

16 ModelResponse, 

17 ModelResponsePart, 

18 RetryPromptPart, 

19 SystemPromptPart, 

20 TextPart, 

21 ToolCallPart, 

22 ToolReturnPart, 

23 UserPromptPart, 

24) 

25from ..providers import Provider, infer_provider 

26from ..settings import ModelSettings 

27from ..tools import ToolDefinition 

28from . import ( 

29 Model, 

30 ModelRequestParameters, 

31 check_allow_model_requests, 

32) 

33 

34try: 

35 from cohere import ( 

36 AssistantChatMessageV2, 

37 AsyncClientV2, 

38 ChatMessageV2, 

39 ChatResponse, 

40 SystemChatMessageV2, 

41 ToolCallV2, 

42 ToolCallV2Function, 

43 ToolChatMessageV2, 

44 ToolV2, 

45 ToolV2Function, 

46 UserChatMessageV2, 

47 ) 

48 from cohere.core.api_error import ApiError 

49 from cohere.v2.client import OMIT 

50except ImportError as _import_error: 

51 raise ImportError( 

52 'Please install `cohere` to use the Cohere model, ' 

53 'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`' 

54 ) from _import_error 

55 

LatestCohereModelNames = Literal[
    # Aya Expanse models
    'c4ai-aya-expanse-32b',
    'c4ai-aya-expanse-8b',
    # Command models
    'command',
    'command-light',
    'command-light-nightly',
    'command-nightly',
    # Command R models
    'command-r',
    'command-r-03-2024',
    'command-r-08-2024',
    # Command R+ models
    'command-r-plus',
    'command-r-plus-04-2024',
    'command-r-plus-08-2024',
    # Command R7B models
    'command-r7b-12-2024',
]
"""Latest Cohere models."""

CohereModelName = Union[str, LatestCohereModelNames]
"""Possible Cohere model names.

Cohere offers many date-stamped model variants, so only the latest ones are listed explicitly;
the type hint still accepts any string as a model name.
See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
"""

80 

81 

class CohereModelSettings(ModelSettings):
    """Request settings for a Cohere model.

    EVERY FIELD ADDED HERE MUST START WITH `cohere_`, so these settings can be merged with
    the settings of other models without name collisions.
    """

    # Intentionally empty for now: reserved for future Cohere-specific settings.

89 

90 

@dataclass(init=False)
class CohereModel(Model):
    """A model that uses the Cohere API.

    Internally, this uses the [Cohere Python client](
    https://github.com/cohere-ai/cohere-python) to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    # Async Cohere v2 client used for every chat request; hidden from repr.
    client: AsyncClientV2 = field(repr=False)

    # Model name sent as `model=` on every chat request.
    _model_name: CohereModelName = field(repr=False)
    # Value reported by the `system` property.
    _system: str = field(default='cohere', repr=False)

    def __init__(
        self,
        model_name: CohereModelName,
        *,
        provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
    ):
        """Initialize a Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#command).
            provider: The provider to use for authentication and API access. Can be either the string
                'cohere' or an instance of `Provider[AsyncClientV2]`. A string is resolved to a
                provider with `infer_provider`. Defaults to `'cohere'`.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client

    @property
    def base_url(self) -> str:
        """The base URL reported by the underlying Cohere client."""
        # The Cohere SDK only exposes the base URL through its private client wrapper.
        client_wrapper = self.client._client_wrapper  # type: ignore
        return str(client_wrapper.get_base_url())

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, result.Usage]:
        """Make a non-streamed request to Cohere and return the response together with its usage.

        Raises:
            ModelHTTPError: If the Cohere API reports an error with a status code >= 400.
        """
        check_allow_model_requests()
        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
        return self._process_response(response), _map_usage(response)

    @property
    def model_name(self) -> CohereModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    async def _chat(
        self,
        messages: list[ModelMessage],
        model_settings: CohereModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> ChatResponse:
        """Send the mapped messages and tools to Cohere's v2 chat endpoint.

        Settings missing from `model_settings`, and an empty tool list, are passed as the
        Cohere client's `OMIT` sentinel.

        Raises:
            ModelHTTPError: If Cohere raises an `ApiError` with a status code >= 400.
            ApiError: Re-raised unchanged for any other Cohere API error.
        """
        tools = self._get_tools(model_request_parameters)
        # A single pydantic-ai message can map to several Cohere messages, so flatten them.
        cohere_messages = list(chain.from_iterable(self._map_message(m) for m in messages))
        try:
            return await self.client.chat(
                model=self._model_name,
                messages=cohere_messages,
                tools=tools or OMIT,
                max_tokens=model_settings.get('max_tokens', OMIT),
                temperature=model_settings.get('temperature', OMIT),
                # pydantic-ai's `top_p` setting is sent as Cohere's `p` parameter.
                p=model_settings.get('top_p', OMIT),
                seed=model_settings.get('seed', OMIT),
                presence_penalty=model_settings.get('presence_penalty', OMIT),
                frequency_penalty=model_settings.get('frequency_penalty', OMIT),
            )
        except ApiError as e:
            if (status_code := e.status_code) and status_code >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise

    def _process_response(self, response: ChatResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return.

        Each content item becomes a `TextPart`. Each tool call that has a function name and
        arguments becomes a `ToolCallPart`; a missing call id is replaced by a generated one.
        """
        parts: list[ModelResponsePart] = []
        # Cohere returns content as a list for future-proofing and usually sends a single item.
        # Map every item anyway, so additional text is never silently dropped.
        for item in response.message.content or []:
            parts.append(TextPart(item.text))
        for c in response.message.tool_calls or []:
            # Calls without a function name or arguments are skipped.
            if c.function and c.function.name and c.function.arguments:
                parts.append(
                    ToolCallPart(
                        tool_name=c.function.name,
                        args=c.function.arguments,
                        tool_call_id=c.id or _generate_tool_call_id(),
                    )
                )
        return ModelResponse(parts=parts, model_name=self._model_name)

    def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`.

        A `ModelRequest` can yield several Cohere messages (one per part). A `ModelResponse`
        is collapsed into a single assistant message: its text parts are joined with blank
        lines, and its tool calls are attached as `tool_calls`.
        """
        if isinstance(message, ModelRequest):
            yield from self._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[ToolCallV2] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(self._map_tool_call(item))
                else:
                    assert_never(item)
            message_param = AssistantChatMessageV2(role='assistant')
            # Only set `content` / `tool_calls` when there is something to send.
            if texts:
                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
            if tool_calls:
                message_param.tool_calls = tool_calls
            yield message_param
        else:
            assert_never(message)

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
        """Build the Cohere tool list: function tools first, then result tools."""
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
        """Convert a pydantic-ai tool call into a Cohere `ToolCallV2` with JSON-string arguments."""
        return ToolCallV2(
            id=_guard_tool_call_id(t=t),
            type='function',
            function=ToolCallV2Function(
                name=t.tool_name,
                arguments=t.args_as_json_str(),
            ),
        )

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
        """Convert a pydantic-ai tool definition into a Cohere function tool."""
        return ToolV2(
            type='function',
            function=ToolV2Function(
                name=f.name,
                description=f.description,
                parameters=f.parameters_json_schema,
            ),
        )

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
        """Yield one Cohere message per part of a `ModelRequest`.

        Raises:
            RuntimeError: If a user prompt has non-string (multi-modal) content.
        """
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                if isinstance(part.content, str):
                    yield UserChatMessageV2(role='user', content=part.content)
                else:
                    raise RuntimeError('Cohere does not yet support multi-modal inputs.')
            elif isinstance(part, ToolReturnPart):
                yield ToolChatMessageV2(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                # A retry without a tool name is not tied to a tool call, so it goes back as user text.
                if part.tool_name is None:
                    yield UserChatMessageV2(role='user', content=part.model_response())
                else:
                    yield ToolChatMessageV2(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)

274 

275 

def _map_usage(response: ChatResponse) -> result.Usage:
    """Translate the usage block of a Cohere chat response into a pydantic-ai `Usage`.

    Request/response token counts are taken from `usage.tokens`. Every billed unit that is
    present and non-zero is copied into `details` as an integer. A response without usage
    information yields an empty `Usage`.
    """
    usage = response.usage
    if usage is None:
        return result.Usage()

    details: dict[str, int] = {}
    billed = usage.billed_units
    if billed is not None:
        # Record only billed units that are set and non-zero, in a fixed order.
        for unit in ('input_tokens', 'output_tokens', 'search_units', 'classifications'):
            amount = getattr(billed, unit)
            if amount:
                details[unit] = int(amount)

    tokens = usage.tokens
    request_tokens: int | None = None
    response_tokens: int | None = None
    if tokens:
        if tokens.input_tokens:
            request_tokens = int(tokens.input_tokens)
        if tokens.output_tokens:
            response_tokens = int(tokens.output_tokens)

    return result.Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=(request_tokens or 0) + (response_tokens or 0),
        details=details,
    )