Coverage for pydantic_ai_slim/pydantic_ai/models/cohere.py: 93.04%

116 statements  

coverage.py v7.6.7, created at 2025-01-30 19:21 +0000

1from __future__ import annotations as _annotations 

2 

3from collections.abc import Iterable 

4from dataclasses import dataclass, field 

5from itertools import chain 

6from typing import Literal, Union, cast 

7 

8from cohere import TextAssistantMessageContentItem 

9from httpx import AsyncClient as AsyncHTTPClient 

10from typing_extensions import assert_never 

11 

12from .. import result 

13from .._utils import guard_tool_call_id as _guard_tool_call_id 

14from ..messages import ( 

15 ModelMessage, 

16 ModelRequest, 

17 ModelResponse, 

18 ModelResponsePart, 

19 RetryPromptPart, 

20 SystemPromptPart, 

21 TextPart, 

22 ToolCallPart, 

23 ToolReturnPart, 

24 UserPromptPart, 

25) 

26from ..settings import ModelSettings 

27from ..tools import ToolDefinition 

28from . import ( 

29 AgentModel, 

30 Model, 

31 check_allow_model_requests, 

32) 

33 

34try: 

35 from cohere import ( 

36 AssistantChatMessageV2, 

37 AsyncClientV2, 

38 ChatMessageV2, 

39 ChatResponse, 

40 SystemChatMessageV2, 

41 ToolCallV2, 

42 ToolCallV2Function, 

43 ToolChatMessageV2, 

44 ToolV2, 

45 ToolV2Function, 

46 UserChatMessageV2, 

47 ) 

48 from cohere.v2.client import OMIT 

49except ImportError as _import_error: 

50 raise ImportError( 

51 'Please install `cohere` to use the Cohere model, ' 

52 "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`" 

53 ) from _import_error 

54 

55NamedCohereModels = Literal[ 

56 'c4ai-aya-expanse-32b', 

57 'c4ai-aya-expanse-8b', 

58 'command', 

59 'command-light', 

60 'command-light-nightly', 

61 'command-nightly', 

62 'command-r', 

63 'command-r-03-2024', 

64 'command-r-08-2024', 

65 'command-r-plus', 

66 'command-r-plus-04-2024', 

67 'command-r-plus-08-2024', 

68 'command-r7b-12-2024', 

69] 

70"""Latest / most popular named Cohere models.""" 

71 

72CohereModelName = Union[NamedCohereModels, str] 

73 

74 

75class CohereModelSettings(ModelSettings): 

76 """Settings used for a Cohere model request.""" 

77 

78 # This class is a placeholder for any future cohere-specific settings 

79 

80 

81@dataclass(init=False) 

82class CohereModel(Model): 

83 """A model that uses the Cohere API. 

84 

85 Internally, this uses the [Cohere Python client]( 

86 https://github.com/cohere-ai/cohere-python) to interact with the API. 

87 

88 Apart from `__init__`, all methods are private or match those of the base class. 

89 """ 

90 

91 model_name: CohereModelName 

92 client: AsyncClientV2 = field(repr=False) 

93 

94 def __init__( 

95 self, 

96 model_name: CohereModelName, 

97 *, 

98 api_key: str | None = None, 

99 cohere_client: AsyncClientV2 | None = None, 

100 http_client: AsyncHTTPClient | None = None, 

101 ): 

102 """Initialize a Cohere model. 

103 

104 Args: 

105 model_name: The name of the Cohere model to use. List of model names 

106 available [here](https://docs.cohere.com/docs/models#command). 

107 api_key: The API key to use for authentication; if not provided, the 

108 `CO_API_KEY` environment variable will be used if available. 

109 cohere_client: An existing Cohere async client to use. If provided, 

110 `api_key` and `http_client` must be `None`. 

111 http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. 

112 """ 

113 self.model_name: CohereModelName = model_name 

114 if cohere_client is not None: 

115 assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`' 

116 assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`' 

117 self.client = cohere_client 

118 else: 

119 self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client) # type: ignore 
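# Illustrative construction modes (example values, not part of the measured file):
#   CohereModel('command-r-plus')                   -> key read from CO_API_KEY
#   CohereModel('command-r-plus', api_key='...')    -> explicit API key
#   CohereModel('command-r-plus', cohere_client=c)  -> reuse an existing AsyncClientV2 `c`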

120 

121 async def agent_model( 

122 self, 

123 *, 

124 function_tools: list[ToolDefinition], 

125 allow_text_result: bool, 

126 result_tools: list[ToolDefinition], 

127 ) -> AgentModel: 

128 check_allow_model_requests() 

129 tools = [self._map_tool_definition(r) for r in function_tools] 

130 if result_tools: 

131 tools += [self._map_tool_definition(r) for r in result_tools] 

132 return CohereAgentModel( 

133 self.client, 

134 self.model_name, 

135 allow_text_result, 

136 tools, 

137 ) 

138 

139 def name(self) -> str: 

140 return f'cohere:{self.model_name}' 

141 

142 @staticmethod 

143 def _map_tool_definition(f: ToolDefinition) -> ToolV2: 

144 return ToolV2( 

145 type='function', 

146 function=ToolV2Function( 

147 name=f.name, 

148 description=f.description, 

149 parameters=f.parameters_json_schema, 

150 ), 

151 ) 
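# Illustrative mapping (hypothetical tool): a ToolDefinition with name='get_weather',
# description='Get the weather for a city' and parameters_json_schema
# {'type': 'object', 'properties': {'city': {'type': 'string'}}} becomes
# ToolV2(type='function', function=ToolV2Function(name='get_weather',
# description='Get the weather for a city', parameters={'type': 'object', ...})).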

152 

153 

154@dataclass 

155class CohereAgentModel(AgentModel): 

156 """Implementation of `AgentModel` for Cohere models.""" 

157 

158 client: AsyncClientV2 

159 model_name: CohereModelName 

160 allow_text_result: bool 

161 tools: list[ToolV2] 

162 

163 async def request( 

164 self, messages: list[ModelMessage], model_settings: ModelSettings | None 

165 ) -> tuple[ModelResponse, result.Usage]: 

166 response = await self._chat(messages, cast(CohereModelSettings, model_settings or {})) 

167 return self._process_response(response), _map_usage(response) 

168 

169 async def _chat( 

170 self, 

171 messages: list[ModelMessage], 

172 model_settings: CohereModelSettings, 

173 ) -> ChatResponse: 

174 cohere_messages = list(chain(*(self._map_message(m) for m in messages))) 

175 return await self.client.chat( 

176 model=self.model_name, 

177 messages=cohere_messages, 

178 tools=self.tools or OMIT, 

179 max_tokens=model_settings.get('max_tokens', OMIT), 

180 temperature=model_settings.get('temperature', OMIT), 

181 p=model_settings.get('top_p', OMIT), 

182 seed=model_settings.get('seed', OMIT), 

183 presence_penalty=model_settings.get('presence_penalty', OMIT), 

184 frequency_penalty=model_settings.get('frequency_penalty', OMIT), 

185 ) 
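# Any setting not supplied resolves to Cohere's OMIT sentinel above, so the SDK
# leaves that field out of the request body rather than sending an explicit null.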

186 

187 def _process_response(self, response: ChatResponse) -> ModelResponse: 

188 """Process a non-streamed response, and prepare a message to return.""" 

189 parts: list[ModelResponsePart] = [] 

190 if response.message.content is not None and len(response.message.content) > 0: 

191 # While Cohere's API returns a list, it only does that for future-proofing 

192 # and currently only one item is returned. 

193 choice = response.message.content[0] 

194 parts.append(TextPart(choice.text)) 

195 for c in response.message.tool_calls or []: 

196 if c.function and c.function.name and c.function.arguments:  [196 ↛ 195: line 196 didn't jump to line 195 because the condition on line 196 was always true]

197 parts.append( 

198 ToolCallPart( 

199 tool_name=c.function.name, 

200 args=c.function.arguments, 

201 tool_call_id=c.id, 

202 ) 

203 ) 

204 return ModelResponse(parts=parts, model_name=self.model_name) 

205 

206 @classmethod 

207 def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]: 

208 """Maps a `pydantic_ai.ModelMessage` to one or more `cohere.ChatMessageV2` messages.""" 

209 if isinstance(message, ModelRequest): 

210 yield from cls._map_user_message(message) 

211 elif isinstance(message, ModelResponse): 

212 texts: list[str] = [] 

213 tool_calls: list[ToolCallV2] = [] 

214 for item in message.parts: 

215 if isinstance(item, TextPart): 

216 texts.append(item.content) 

217 elif isinstance(item, ToolCallPart): 

218 tool_calls.append(_map_tool_call(item)) 

219 else: 

220 assert_never(item) 

221 message_param = AssistantChatMessageV2(role='assistant') 

222 if texts: 

223 message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))] 

224 if tool_calls: 

225 message_param.tool_calls = tool_calls 

226 yield message_param 

227 else: 

228 assert_never(message) 

229 

230 @classmethod 

231 def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]: 

232 for part in message.parts: 

233 if isinstance(part, SystemPromptPart): 

234 yield SystemChatMessageV2(role='system', content=part.content) 

235 elif isinstance(part, UserPromptPart): 

236 yield UserChatMessageV2(role='user', content=part.content) 

237 elif isinstance(part, ToolReturnPart): 

238 yield ToolChatMessageV2( 

239 role='tool', 

240 tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'), 

241 content=part.model_response_str(), 

242 ) 

243 elif isinstance(part, RetryPromptPart): 

244 if part.tool_name is None:  [244 ↛ 245: line 244 didn't jump to line 245 because the condition on line 244 was never true]

245 yield UserChatMessageV2(role='user', content=part.model_response()) 

246 else: 

247 yield ToolChatMessageV2( 

248 role='tool', 

249 tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'), 

250 content=part.model_response(), 

251 ) 

252 else: 

253 assert_never(part) 

254 

255 

256def _map_tool_call(t: ToolCallPart) -> ToolCallV2: 

257 return ToolCallV2( 

258 id=_guard_tool_call_id(t=t, model_source='Cohere'), 

259 type='function', 

260 function=ToolCallV2Function( 

261 name=t.tool_name, 

262 arguments=t.args_as_json_str(), 

263 ), 

264 ) 
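# Illustrative mapping (hypothetical call): a ToolCallPart with tool_name='get_weather'
# and args {'city': 'London'} becomes ToolCallV2(id=<tool call id>, type='function',
# function=ToolCallV2Function(name='get_weather', arguments='{"city": "London"}')).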

265 

266 

267def _map_usage(response: ChatResponse) -> result.Usage: 

268 usage = response.usage 

269 if usage is None: 

270 return result.Usage() 

271 else: 

272 details: dict[str, int] = {} 

273 if usage.billed_units is not None: 

274 if usage.billed_units.input_tokens:  [274 ↛ 276: line 274 didn't jump to line 276 because the condition on line 274 was always true]

275 details['input_tokens'] = int(usage.billed_units.input_tokens) 

276 if usage.billed_units.output_tokens:  [276 ↛ 278: line 276 didn't jump to line 278 because the condition on line 276 was always true]

277 details['output_tokens'] = int(usage.billed_units.output_tokens) 

278 if usage.billed_units.search_units:  [278 ↛ 279: line 278 didn't jump to line 279 because the condition on line 278 was never true]

279 details['search_units'] = int(usage.billed_units.search_units) 

280 if usage.billed_units.classifications:  [280 ↛ 281: line 280 didn't jump to line 281 because the condition on line 280 was never true]

281 details['classifications'] = int(usage.billed_units.classifications) 

282 

283 request_tokens = int(usage.tokens.input_tokens) if usage.tokens and usage.tokens.input_tokens else None 

284 response_tokens = int(usage.tokens.output_tokens) if usage.tokens and usage.tokens.output_tokens else None 

285 return result.Usage( 

286 request_tokens=request_tokens, 

287 response_tokens=response_tokens, 

288 total_tokens=(request_tokens or 0) + (response_tokens or 0), 

289 details=details, 

290 )
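
A minimal usage sketch for this module (not part of the measured file; it assumes the
pydantic_ai `Agent` API current at the time of this coverage run, i.e. that `Agent` accepts a
`Model` instance and that `run_sync` returns a result exposing `.data`):

from pydantic_ai import Agent
from pydantic_ai.models.cohere import CohereModel

# The API key is read from the CO_API_KEY environment variable (see __init__ above).
model = CohereModel('command-r-plus')
agent = Agent(model, system_prompt='Reply concisely.')

result = agent.run_sync('What is the capital of France?')
print(result.data)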