Coverage for pydantic_ai_slim/pydantic_ai/models/cohere.py: 92.86%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3from collections.abc import Iterable 

4from dataclasses import dataclass, field 

5from itertools import chain 

6from typing import Literal, TypeAlias, Union, cast 

7 

8from cohere import TextAssistantMessageContentItem 

9from typing_extensions import assert_never 

10 

11from .. import result 

12from .._utils import guard_tool_call_id as _guard_tool_call_id 

13from ..messages import ( 

14 ModelMessage, 

15 ModelRequest, 

16 ModelResponse, 

17 ModelResponsePart, 

18 RetryPromptPart, 

19 SystemPromptPart, 

20 TextPart, 

21 ToolCallPart, 

22 ToolReturnPart, 

23 UserPromptPart, 

24) 

25from ..settings import ModelSettings 

26from ..tools import ToolDefinition 

27from . import ( 

28 AgentModel, 

29 Model, 

30 check_allow_model_requests, 

31) 

32 

33try: 

34 from cohere import ( 

35 AssistantChatMessageV2, 

36 AsyncClientV2, 

37 ChatMessageV2, 

38 ChatResponse, 

39 SystemChatMessageV2, 

40 ToolCallV2, 

41 ToolCallV2Function, 

42 ToolChatMessageV2, 

43 ToolV2, 

44 ToolV2Function, 

45 UserChatMessageV2, 

46 ) 

47 from cohere.v2.client import OMIT 

48except ImportError as _import_error: 

49 raise ImportError( 

50 'Please install `cohere` to use the Cohere model, ' 

51 "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`" 

52 ) from _import_error 

53 

54CohereModelName: TypeAlias = Union[ 

55 str, 

56 Literal[ 

57 'c4ai-aya-expanse-32b', 

58 'c4ai-aya-expanse-8b', 

59 'command', 

60 'command-light', 

61 'command-light-nightly', 

62 'command-nightly', 

63 'command-r', 

64 'command-r-03-2024', 

65 'command-r-08-2024', 

66 'command-r-plus', 

67 'command-r-plus-04-2024', 

68 'command-r-plus-08-2024', 

69 'command-r7b-12-2024', 

70 ], 

71] 

72 

73 

class CohereModelSettings(ModelSettings):
    """Settings used for a Cohere model request."""

    # No Cohere-specific keys exist yet; the generic `ModelSettings` keys are
    # inherited unchanged. Reserved for future Cohere-only options.

78 

79 

@dataclass(init=False)
class CohereModel(Model):
    """A model backed by the Cohere API.

    Requests are sent through the [Cohere Python client](
    https://github.com/cohere-ai/cohere-python).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    model_name: CohereModelName
    client: AsyncClientV2 = field(repr=False)

    def __init__(
        self,
        model_name: CohereModelName,
        *,
        api_key: str | None = None,
        cohere_client: AsyncClientV2 | None = None,
    ):
        """Initialize a Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#command).
            api_key: The API key to use for authentication, if not provided, the
                `COHERE_API_KEY` environment variable will be used if available.
            cohere_client: An existing Cohere async client to use. If provided,
                `api_key` must be `None`.
        """
        self.model_name: CohereModelName = model_name
        if cohere_client is None:
            # No client supplied: build one from the (possibly None) API key.
            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
            return
        assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
        self.client = cohere_client

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Build a `CohereAgentModel` carrying both function and result tools.

        Function tools come first, followed by result tools, each mapped to a
        Cohere `ToolV2`.
        """
        check_allow_model_requests()
        mapped_tools = [self._map_tool_definition(tool_def) for tool_def in [*function_tools, *result_tools]]
        return CohereAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            mapped_tools,
        )

    def name(self) -> str:
        """Return the identifier `cohere:<model_name>`."""
        return f'cohere:{self.model_name}'

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
        """Convert a pydantic-ai `ToolDefinition` into a Cohere function tool."""
        tool_function = ToolV2Function(
            name=f.name,
            description=f.description,
            parameters=f.parameters_json_schema,
        )
        return ToolV2(type='function', function=tool_function)

148 

149 

@dataclass
class CohereAgentModel(AgentModel):
    """Implementation of `AgentModel` for Cohere models.

    Holds the Cohere client, the model name and the already-mapped tool
    definitions. Translates pydantic-ai messages to Cohere's V2 chat format and
    maps Cohere responses back.
    """

    client: AsyncClientV2
    model_name: CohereModelName
    # Stored on the instance; not read by any method of this class.
    allow_text_result: bool
    tools: list[ToolV2]

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, result.Usage]:
        """Send `messages` to Cohere and return the mapped response and its usage."""
        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    async def _chat(
        self,
        messages: list[ModelMessage],
        model_settings: CohereModelSettings,
    ) -> ChatResponse:
        """Map `messages` to Cohere's format and call the chat endpoint.

        Any setting absent from `model_settings` is passed as the client's
        `OMIT` sentinel instead of a concrete value.
        """
        # Each pydantic-ai message may map to several Cohere messages; flatten them.
        cohere_messages = list(chain.from_iterable(self._map_message(m) for m in messages))
        return await self.client.chat(
            model=self.model_name,
            messages=cohere_messages,
            # An empty tool list is passed as OMIT rather than `[]`.
            tools=self.tools or OMIT,
            max_tokens=model_settings.get('max_tokens', OMIT),
            temperature=model_settings.get('temperature', OMIT),
            # pydantic-ai's generic `top_p` setting maps to Cohere's `p` parameter.
            p=model_settings.get('top_p', OMIT),
            seed=model_settings.get('seed', OMIT),
            presence_penalty=model_settings.get('presence_penalty', OMIT),
            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
        )

    def _process_response(self, response: ChatResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return.

        Every text content item becomes a `TextPart`. Each tool call that has a
        function, a name and non-empty arguments becomes a `ToolCallPart`. Tool
        calls missing any of those are skipped.
        """
        parts: list[ModelResponsePart] = []
        # Cohere returns content as a list. Usually only one item is present, but
        # keep every text item so additional content is not silently dropped.
        # Items without a `text` attribute are skipped.
        for item in response.message.content or []:
            text = getattr(item, 'text', None)
            if text is not None:
                parts.append(TextPart(text))
        for c in response.message.tool_calls or []:
            if c.function and c.function.name and c.function.arguments:
                parts.append(
                    ToolCallPart(
                        tool_name=c.function.name,
                        args=c.function.arguments,
                        tool_call_id=c.id,
                    )
                )
        return ModelResponse(parts=parts, model_name=self.model_name)

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`.

        A `ModelRequest` may yield several Cohere messages (one per part). A
        `ModelResponse` yields exactly one assistant message: its text parts are
        joined with blank lines, and its tool calls are attached.
        """
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            texts: list[str] = []
            tool_calls: list[ToolCallV2] = []
            for item in message.parts:
                if isinstance(item, TextPart):
                    texts.append(item.content)
                elif isinstance(item, ToolCallPart):
                    tool_calls.append(_map_tool_call(item))
                else:
                    assert_never(item)
            message_param = AssistantChatMessageV2(role='assistant')
            if texts:
                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
            if tool_calls:
                message_param.tool_calls = tool_calls
            yield message_param
        else:
            assert_never(message)

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
        """Yield one Cohere message per part of a pydantic-ai `ModelRequest`.

        System prompts become system messages and user prompts become user
        messages. Tool returns become tool messages. Retry prompts become tool
        messages when tied to a tool, and user messages otherwise.
        """
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                yield UserChatMessageV2(role='user', content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield ToolChatMessageV2(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    # Not tied to a tool call, so it cannot be a tool message.
                    yield UserChatMessageV2(role='user', content=part.model_response())
                else:
                    yield ToolChatMessageV2(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)

250 

251 

def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
    """Convert a pydantic-ai `ToolCallPart` into a Cohere `ToolCallV2`.

    The arguments are serialized to a JSON string. The id is passed through
    `_guard_tool_call_id` with `model_source='Cohere'`.
    """
    call_id = _guard_tool_call_id(t=t, model_source='Cohere')
    call_function = ToolCallV2Function(name=t.tool_name, arguments=t.args_as_json_str())
    return ToolCallV2(id=call_id, type='function', function=call_function)

261 

262 

def _map_usage(response: ChatResponse) -> result.Usage:
    """Translate the usage block of a Cohere `ChatResponse` into a `result.Usage`.

    - A response without a `usage` block yields an empty `result.Usage()`.
    - `request_tokens` / `response_tokens` come from `usage.tokens`. Each is
      `None` when Cohere does not report it.
    - `total_tokens` is their sum, with missing counts treated as 0.
    - Non-zero `billed_units` counters are copied, as ints, into `details`.
    """
    usage = response.usage
    if usage is None:
        return result.Usage()

    details: dict[str, int] = {}
    billed_units = usage.billed_units
    if billed_units is not None:
        for key in ('input_tokens', 'output_tokens', 'search_units', 'classifications'):
            value = getattr(billed_units, key)
            # Unset (None) and zero counters are left out of `details`.
            if value:
                details[key] = int(value)

    # Compare against None rather than relying on truthiness, so a reported
    # count of 0 stays 0 instead of becoming "unknown" (None).
    request_tokens: int | None = None
    response_tokens: int | None = None
    tokens = usage.tokens
    if tokens is not None:
        if tokens.input_tokens is not None:
            request_tokens = int(tokens.input_tokens)
        if tokens.output_tokens is not None:
            response_tokens = int(tokens.output_tokens)

    return result.Usage(
        request_tokens=request_tokens,
        response_tokens=response_tokens,
        total_tokens=(request_tokens or 0) + (response_tokens or 0),
        details=details,
    )