Coverage for pydantic_ai_slim/pydantic_ai/models/mistral.py: 96.29%
269 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3import os
4from collections.abc import AsyncIterable, AsyncIterator, Iterable
5from contextlib import asynccontextmanager
6from dataclasses import dataclass, field
7from datetime import datetime, timezone
8from itertools import chain
9from typing import Any, Callable, Literal, Union, cast
11import pydantic_core
12from httpx import AsyncClient as AsyncHTTPClient, Timeout
13from typing_extensions import assert_never
15from .. import UnexpectedModelBehavior, _utils
16from .._utils import now_utc as _now_utc
17from ..messages import (
18 ModelMessage,
19 ModelRequest,
20 ModelResponse,
21 ModelResponsePart,
22 ModelResponseStreamEvent,
23 RetryPromptPart,
24 SystemPromptPart,
25 TextPart,
26 ToolCallPart,
27 ToolReturnPart,
28 UserPromptPart,
29)
30from ..result import Usage
31from ..settings import ModelSettings
32from ..tools import ToolDefinition
33from . import (
34 AgentModel,
35 Model,
36 StreamedResponse,
37 cached_async_http_client,
38 check_allow_model_requests,
39)
try:
    from mistralai import (
        UNSET,
        CompletionChunk as MistralCompletionChunk,
        Content as MistralContent,
        ContentChunk as MistralContentChunk,
        FunctionCall as MistralFunctionCall,
        Mistral,
        OptionalNullable as MistralOptionalNullable,
        TextChunk as MistralTextChunk,
        ToolChoiceEnum as MistralToolChoiceEnum,
    )
    from mistralai.models import (
        ChatCompletionResponse as MistralChatCompletionResponse,
        CompletionEvent as MistralCompletionEvent,
        Messages as MistralMessages,
        Tool as MistralTool,
        ToolCall as MistralToolCall,
    )
    from mistralai.models.assistantmessage import AssistantMessage as MistralAssistantMessage
    from mistralai.models.function import Function as MistralFunction
    from mistralai.models.systemmessage import SystemMessage as MistralSystemMessage
    from mistralai.models.toolmessage import ToolMessage as MistralToolMessage
    from mistralai.models.usermessage import UserMessage as MistralUserMessage
    from mistralai.types.basemodel import Unset as MistralUnset
    from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync
except ImportError as e:
    # The PyPI distribution providing these imports is `mistralai`, not `mistral`
    # (`pip install mistral` installs an unrelated package), so name it correctly.
    raise ImportError(
        'Please install `mistralai` to use the Mistral model, '
        "you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
    ) from e
NamedMistralModels = Literal[
    'mistral-large-latest', 'mistral-small-latest', 'codestral-latest', 'mistral-moderation-latest'
]
"""Latest / most popular named Mistral models."""

MistralModelName = Union[NamedMistralModels, str]
"""Possible Mistral model names.

Since Mistral supports a variety of date-stamped models, we explicitly list the most popular models but
allow any name in the type hints.
See [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_overview/) for a full list.
"""
class MistralModelSettings(ModelSettings):
    """Settings used for a Mistral model request.

    The generic settings read by this module (`max_tokens`, `temperature`, `top_p`,
    `timeout`, `seed`, `presence_penalty`, `frequency_penalty`) are inherited from
    `ModelSettings`.
    """

    # This class is a placeholder for any future mistral-specific settings
@dataclass(init=False)
class MistralModel(Model):
    """A model that uses Mistral.

    Internally, this uses the [Mistral Python client](https://github.com/mistralai/client-python) to interact with the API.

    [API Documentation](https://docs.mistral.ai/)
    """

    model_name: MistralModelName
    client: Mistral = field(repr=False)

    def __init__(
        self,
        model_name: MistralModelName,
        *,
        api_key: str | Callable[[], str | None] | None = None,
        client: Mistral | None = None,
        http_client: AsyncHTTPClient | None = None,
    ):
        """Initialize a Mistral model.

        Args:
            model_name: The name of the model to use.
            api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
            client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.model_name = model_name

        if client is None:
            # No pre-built client: fall back to the environment variable when no
            # key was given explicitly, and reuse the shared HTTP client by default.
            if api_key is None:
                api_key = os.getenv('MISTRAL_API_KEY')
            self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
        else:
            assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
            self.client = client

    async def agent_model(
        self,
        *,
        function_tools: list[ToolDefinition],
        allow_text_result: bool,
        result_tools: list[ToolDefinition],
    ) -> AgentModel:
        """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
        check_allow_model_requests()
        return MistralAgentModel(
            self.client,
            self.model_name,
            allow_text_result,
            function_tools,
            result_tools,
        )

    def name(self) -> str:
        """Return the qualified model name, prefixed with the provider."""
        return 'mistral:' + str(self.model_name)
@dataclass
class MistralAgentModel(AgentModel):
    """Implementation of `AgentModel` for Mistral models.

    Holds the shared `Mistral` client plus the per-run tool definitions, and
    translates between Pydantic AI messages and the Mistral chat API.
    """

    client: Mistral
    model_name: MistralModelName
    allow_text_result: bool
    function_tools: list[ToolDefinition]
    result_tools: list[ToolDefinition]
    # Prompt appended as a user message in JSON mode; `{schema}` is filled in by
    # `_generate_user_output_format`.
    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n"""

    async def request(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> tuple[ModelResponse, Usage]:
        """Make a non-streaming request to the model from Pydantic AI call."""
        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
        return self._process_response(response), _map_usage(response)

    @asynccontextmanager
    async def request_stream(
        self, messages: list[ModelMessage], model_settings: ModelSettings | None
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming request to the model from Pydantic AI call."""
        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
        async with response:
            yield await self._process_streamed_response(self.result_tools, response)

    async def _completions_create(
        self, messages: list[ModelMessage], model_settings: MistralModelSettings
    ) -> MistralChatCompletionResponse:
        """Make a non-streaming request to the model."""
        response = await self.client.chat.complete_async(
            model=str(self.model_name),
            messages=list(chain(*(self._map_message(m) for m in messages))),
            n=1,
            tools=self._map_function_and_result_tools_definition() or UNSET,
            tool_choice=self._get_tool_choice(),
            stream=False,
            max_tokens=model_settings.get('max_tokens', UNSET),
            temperature=model_settings.get('temperature', UNSET),
            top_p=model_settings.get('top_p', 1),
            timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
            random_seed=model_settings.get('seed', UNSET),
        )
        assert response, 'A unexpected empty response from Mistral.'
        return response

    async def _stream_completions_create(
        self,
        messages: list[ModelMessage],
        model_settings: MistralModelSettings,
    ) -> MistralEventStreamAsync[MistralCompletionEvent]:
        """Create a streaming completion request to the Mistral model.

        Three request shapes are used depending on which tools are registered:
        function calling when any function tools exist (result tools, if any, are
        sent alongside), "JSON mode" when only result tools exist, and plain text
        streaming when there are no tools at all.
        """
        response: MistralEventStreamAsync[MistralCompletionEvent] | None
        mistral_messages = list(chain(*(self._map_message(m) for m in messages)))

        # Simplified from `self.result_tools and self.function_tools or self.function_tools`,
        # which is logically equivalent to testing `self.function_tools` alone.
        if self.function_tools:
            # Function Calling
            response = await self.client.chat.stream_async(
                model=str(self.model_name),
                messages=mistral_messages,
                n=1,
                tools=self._map_function_and_result_tools_definition() or UNSET,
                tool_choice=self._get_tool_choice(),
                temperature=model_settings.get('temperature', UNSET),
                top_p=model_settings.get('top_p', 1),
                max_tokens=model_settings.get('max_tokens', UNSET),
                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                presence_penalty=model_settings.get('presence_penalty'),
                frequency_penalty=model_settings.get('frequency_penalty'),
            )

        elif self.result_tools:
            # Json Mode: describe the expected output schema in a trailing user message.
            parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]
            user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
            mistral_messages.append(user_output_format_message)

            response = await self.client.chat.stream_async(
                model=str(self.model_name),
                messages=mistral_messages,
                response_format={'type': 'json_object'},
                stream=True,
            )

        else:
            # Stream Mode: plain text, no tools.
            response = await self.client.chat.stream_async(
                model=str(self.model_name),
                messages=mistral_messages,
                stream=True,
            )
        assert response, 'A unexpected empty response from Mistral.'
        return response

    def _get_tool_choice(self) -> MistralToolChoiceEnum | None:
        """Get tool choice for the model.

        - "auto": Default mode. Model decides if it uses the tool or not.
        - "any": Select any tool.
        - "none": Prevents tool use.
        - "required": Forces tool use.
        """
        if not self.function_tools and not self.result_tools:
            return None
        elif not self.allow_text_result:
            return 'required'
        else:
            return 'auto'

    def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None:
        """Map function and result tools to MistralTool format.

        Returns None if both function_tools and result_tools are empty.
        """
        all_tools: list[ToolDefinition] = self.function_tools + self.result_tools
        tools = [
            MistralTool(
                function=MistralFunction(name=r.name, parameters=r.parameters_json_schema, description=r.description)
            )
            for r in all_tools
        ]
        return tools if tools else None

    def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        assert response.choices, 'Unexpected empty response choice.'

        # Prefer the server-reported creation time; fall back to "now".
        if response.created:
            timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        else:
            timestamp = _now_utc()

        choice = response.choices[0]
        content = choice.message.content
        tool_calls = choice.message.tool_calls

        parts: list[ModelResponsePart] = []
        if text := _map_content(content):
            parts.append(TextPart(content=text))

        if isinstance(tool_calls, list):
            for tool_call in tool_calls:
                tool = _map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                parts.append(tool)

        return ModelResponse(parts, model_name=self.model_name, timestamp=timestamp)

    async def _process_streamed_response(
        self,
        result_tools: list[ToolDefinition],
        response: MistralEventStreamAsync[MistralCompletionEvent],
    ) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return.

        Raises:
            UnexpectedModelBehavior: If the stream ends before any chunk arrives.
        """
        peekable_response = _utils.PeekableAsyncStream(response)
        first_chunk = await peekable_response.peek()
        if isinstance(first_chunk, _utils.Unset):
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        if first_chunk.data.created:
            timestamp = datetime.fromtimestamp(first_chunk.data.created, tz=timezone.utc)
        else:
            # Use the shared helper for consistency with `_process_response`.
            timestamp = _now_utc()

        return MistralStreamedResponse(
            _response=peekable_response,
            _model_name=self.model_name,
            _timestamp=timestamp,
            _result_tools={c.name: c for c in result_tools},
        )

    @staticmethod
    def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
        """Maps a pydantic-ai ToolCall to a MistralToolCall."""
        return MistralToolCall(
            id=t.tool_call_id,
            type='function',
            function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
        )

    def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
        """Get a message with an example of the expected output format."""
        examples: list[dict[str, Any]] = []
        for schema in schemas:
            typed_dict_definition: dict[str, Any] = {}
            for key, value in schema.get('properties', {}).items():
                typed_dict_definition[key] = self._get_python_type(value)
            examples.append(typed_dict_definition)

        # A single schema is shown directly; multiple schemas are shown as a list.
        example_schema = examples[0] if len(examples) == 1 else examples
        return MistralUserMessage(content=self.json_mode_schema_prompt.format(schema=example_schema))

    @classmethod
    def _get_python_type(cls, value: dict[str, Any]) -> str:
        """Return a string representation of the Python type for a single JSON schema property.

        This function handles recursion for nested arrays/objects and `anyOf`.
        """
        # 1) Handle anyOf first, because it's a different schema structure
        if any_of := value.get('anyOf'):
            # Simplistic approach: pick the first option in anyOf
            # (In reality, you'd possibly want to merge or union types)
            return f'Optional[{cls._get_python_type(any_of[0])}]'

        # 2) If we have a top-level "type" field
        value_type = value.get('type')
        if not value_type:
            # No explicit type; fallback
            return 'Any'

        # 3) Direct simple type mapping (string, integer, float, bool, None)
        if value_type in SIMPLE_JSON_TYPE_MAPPING and value_type != 'array' and value_type != 'object':
            return SIMPLE_JSON_TYPE_MAPPING[value_type]

        # 4) Array: Recursively get the item type
        if value_type == 'array':
            items = value.get('items', {})
            return f'list[{cls._get_python_type(items)}]'

        # 5) Object: Check for additionalProperties
        if value_type == 'object':
            additional_properties = value.get('additionalProperties', {})
            additional_properties_type = additional_properties.get('type')
            if (
                additional_properties_type in SIMPLE_JSON_TYPE_MAPPING
                and additional_properties_type != 'array'
                and additional_properties_type != 'object'
            ):
                # dict[str, bool/int/float/etc...]
                return f'dict[str, {SIMPLE_JSON_TYPE_MAPPING[additional_properties_type]}]'
            elif additional_properties_type == 'array':
                array_items = additional_properties.get('items', {})
                return f'dict[str, list[{cls._get_python_type(array_items)}]]'
            elif additional_properties_type == 'object':
                # nested dictionary of unknown shape
                return 'dict[str, dict[str, Any]]'
            else:
                # If no additionalProperties type or something else, default to a generic dict
                return 'dict[str, Any]'

        # 6) Fallback
        return 'Any'

    @staticmethod
    def _get_timeout_ms(timeout: Timeout | float | None) -> int | None:
        """Convert a timeout to milliseconds.

        Raises:
            NotImplementedError: If an `httpx.Timeout` object is passed; only plain
                numeric timeouts are supported for now.
        """
        if timeout is None:
            return None
        # Accept ints as well as floats: the previous `isinstance(timeout, float)`
        # check wrongly raised NotImplementedError for integer timeouts.
        if isinstance(timeout, (int, float)):
            return int(1000 * timeout)
        raise NotImplementedError('Timeout object is not yet supported for MistralModel.')

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]:
        """Map each part of a `ModelRequest` to the corresponding Mistral message type."""
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield MistralSystemMessage(content=part.content)
            elif isinstance(part, UserPromptPart):
                yield MistralUserMessage(content=part.content)
            elif isinstance(part, ToolReturnPart):
                yield MistralToolMessage(
                    tool_call_id=part.tool_call_id,
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                # Retries without an associated tool go back as a user message;
                # tool-specific retries go back as a tool message.
                if part.tool_name is None:
                    yield MistralUserMessage(content=part.model_response())
                else:
                    yield MistralToolMessage(
                        tool_call_id=part.tool_call_id,
                        content=part.model_response(),
                    )
            else:
                assert_never(part)

    @classmethod
    def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]:
        """Just maps a `pydantic_ai.Message` to a `MistralMessage`."""
        if isinstance(message, ModelRequest):
            yield from cls._map_user_message(message)
        elif isinstance(message, ModelResponse):
            content_chunks: list[MistralContentChunk] = []
            tool_calls: list[MistralToolCall] = []

            for part in message.parts:
                if isinstance(part, TextPart):
                    content_chunks.append(MistralTextChunk(text=part.content))
                elif isinstance(part, ToolCallPart):
                    tool_calls.append(cls._map_to_mistral_tool_call(part))
                else:
                    assert_never(part)
            yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
        else:
            assert_never(message)
# Alias for a Mistral tool-call identifier; None when the id is absent.
MistralToolCallId = Union[str, None]
@dataclass
class MistralStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Mistral models."""

    _response: AsyncIterable[MistralCompletionEvent]
    _timestamp: datetime
    _result_tools: dict[str, ToolDefinition]

    # Accumulated text deltas, incrementally parsed as JSON when result tools exist.
    _delta_content: str = field(default='', init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Yield stream events for each completion chunk, tracking usage as we go."""
        event: MistralCompletionEvent
        async for event in self._response:
            self._usage += _map_usage(event.data)

            try:
                choice = event.data.choices[0]
            except IndexError:
                # A chunk without choices carries nothing to emit.
                continue

            # Text portion of the delta, if any.
            text = _map_content(choice.delta.content)
            if text:
                if not self._result_tools:
                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
                else:
                    # Attempt to produce a result tool call from the accumulated text.
                    self._delta_content += text
                    tool_call_part = self._try_get_result_tool_from_text(self._delta_content, self._result_tools)
                    if tool_call_part is not None:
                        yield self._parts_manager.handle_tool_call_part(
                            vendor_part_id='result',
                            tool_name=tool_call_part.tool_name,
                            args=tool_call_part.args_as_dict(),
                            tool_call_id=tool_call_part.tool_call_id,
                        )

            # Explicit tool calls: Mistral sends them whole, so forward each directly.
            for index, delta_tool_call in enumerate(choice.delta.tool_calls or []):
                yield self._parts_manager.handle_tool_call_part(
                    vendor_part_id=index,
                    tool_name=delta_tool_call.function.name,
                    args=delta_tool_call.function.arguments,
                    tool_call_id=delta_tool_call.id,
                )

    def timestamp(self) -> datetime:
        """Timestamp captured from the first streamed chunk."""
        return self._timestamp

    @staticmethod
    def _try_get_result_tool_from_text(text: str, result_tools: dict[str, ToolDefinition]) -> ToolCallPart | None:
        """Parse (possibly partial) JSON text and match it against a result tool schema."""
        output_json: dict[str, Any] | None = pydantic_core.from_json(text, allow_partial='trailing-strings')
        if not output_json:
            return None

        for result_tool in result_tools.values():
            # NOTE: Additional verification to prevent JSON validation to crash in `_result.py`
            # Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
            # Example with BaseModel and required fields.
            if MistralStreamedResponse._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
                # The following part_id will be thrown away
                return ToolCallPart(tool_name=result_tool.name, args=output_json)
        return None

    @staticmethod
    def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
        """Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
        properties = json_schema.get('properties', {})

        for param in json_schema.get('required', []):
            if param not in json_dict:
                return False

            value = json_dict[param]
            param_schema = properties.get(param, {})
            param_type = param_schema.get('type')
            items_type = param_schema.get('items', {}).get('type')

            if param_type == 'array' and items_type:
                if not isinstance(value, list):
                    return False
                if any(not isinstance(item, VALID_JSON_TYPE_MAPPING[items_type]) for item in value):
                    return False
            elif param_type and not isinstance(value, VALID_JSON_TYPE_MAPPING[param_type]):
                return False

            if isinstance(value, dict) and 'properties' in param_schema:
                # Recurse into nested object schemas.
                if not MistralStreamedResponse._validate_required_json_schema(value, param_schema):
                    return False

        return True
# Maps JSON schema type names to the Python types accepted by `isinstance` checks
# in `MistralStreamedResponse._validate_required_json_schema`.
# NOTE: per JSON Schema, "number" matches integers as well as floats, so it maps
# to `(int, float)` — `isinstance(1, float)` alone would wrongly reject integer
# values parsed from JSON.
VALID_JSON_TYPE_MAPPING: dict[str, Any] = {
    'string': str,
    'integer': int,
    'number': (int, float),
    'boolean': bool,
    'array': list,
    'object': dict,
    'null': type(None),
}
# Maps JSON schema type names to Python type *names*, used when rendering an
# example schema for the JSON-mode prompt (`_get_python_type`).
# 'object' is deliberately absent: it is handled by a dedicated branch in
# `_get_python_type` that inspects `additionalProperties`.
SIMPLE_JSON_TYPE_MAPPING = {
    'string': 'str',
    'integer': 'int',
    'number': 'float',
    'boolean': 'bool',
    'array': 'list',
    'null': 'None',
}
def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
    """Maps a MistralToolCall to a ToolCall."""
    # An empty-string id is normalized to None.
    return ToolCallPart(
        tool_call.function.name,
        tool_call.function.arguments,
        tool_call.id or None,
    )
def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:
    """Maps a Mistral Completion Chunk or Chat Completion Response to a Usage."""
    usage = response.usage
    if not usage:
        # No usage information in this chunk/response.
        return Usage()
    return Usage(
        request_tokens=usage.prompt_tokens,
        response_tokens=usage.completion_tokens,
        total_tokens=usage.total_tokens,
        details=None,
    )
def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None:
    """Maps the delta content from a Mistral Completion Chunk to a string or None.

    Handles the three shapes the API may return: an unset/empty value, a list of
    content chunks, or a plain string.
    """
    result: str | None = None

    if isinstance(content, MistralUnset) or not content:
        result = None
    elif isinstance(content, list):
        for chunk in content:
            if isinstance(chunk, MistralTextChunk):
                # BUG FIX: was `result = result or '' + chunk.text`, which binds as
                # `result or ('' + chunk.text)` and silently dropped every chunk
                # after the first non-empty one; parenthesize to concatenate all.
                result = (result or '') + chunk.text
            else:
                assert False, f'Other data types like (Image, Reference) are not yet supported, got {type(chunk)}'
    elif isinstance(content, str):
        result = content

    # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
    # BUG FIX: `if result and len(result) == 0` was unreachable (a truthy string is
    # never empty); use an explicit None check so '' is normalized to None as intended.
    if result is not None and len(result) == 0:
        result = None

    return result