Coverage for pydantic_ai_slim/pydantic_ai/models/fallback.py: 96.34%
68 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1from __future__ import annotations as _annotations
3from collections.abc import AsyncIterator
4from contextlib import AsyncExitStack, asynccontextmanager, suppress
5from dataclasses import dataclass, field
6from typing import TYPE_CHECKING, Callable
8from opentelemetry.trace import get_current_span
10from pydantic_ai.models.instrumented import InstrumentedModel
12from ..exceptions import FallbackExceptionGroup, ModelHTTPError
13from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
15if TYPE_CHECKING:
16 from ..messages import ModelMessage, ModelResponse
17 from ..settings import ModelSettings
18 from ..usage import Usage
@dataclass(init=False)
class FallbackModel(Model):
    """A model that uses one or more fallback models upon failure.

    Models are tried in order: `default_model` first, then each of `fallback_models`.
    An exception raised by a model moves on to the next model only if it satisfies the
    `fallback_on` condition; any other exception is re-raised immediately. If every model
    fails with a fallback-worthy exception, a `FallbackExceptionGroup` containing all of
    them is raised.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    models: list[Model]

    _fallback_on: Callable[[Exception], bool]

    def __init__(
        self,
        default_model: Model | KnownModelName,
        *fallback_models: Model | KnownModelName,
        fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
    ):
        """Initialize a fallback model instance.

        Args:
            default_model: The name or instance of the default model to use.
            fallback_models: The names or instances of the fallback models to use upon failure.
            fallback_on: A callable or tuple of exceptions that should trigger a fallback.
                A single exception class is also accepted and treated as a one-element tuple.
        """
        self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]

        # An exception class is itself callable, so without this normalization
        # `fallback_on=SomeError` would be stored as the predicate and `SomeError(exc)`
        # (an always-truthy instance) would make *every* exception trigger a fallback.
        if isinstance(fallback_on, type) and issubclass(fallback_on, Exception):
            fallback_on = (fallback_on,)

        if isinstance(fallback_on, tuple):
            self._fallback_on = _default_fallback_condition_factory(fallback_on)
        else:
            self._fallback_on = fallback_on

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Try each model in sequence until one succeeds.

        In case of failure, raise a FallbackExceptionGroup with all exceptions.

        Raises:
            FallbackExceptionGroup: If every model failed with a fallback-worthy exception.
            Exception: Any exception for which the fallback condition is false, unchanged.
        """
        exceptions: list[Exception] = []

        for model in self.models:
            try:
                response, usage = await model.request(messages, model_settings, model_request_parameters)
            except Exception as exc:
                if not self._fallback_on(exc):
                    # Not a fallback-worthy error: surface it with its original traceback.
                    raise
                exceptions.append(exc)
                continue

            self._set_span_attributes(model)
            return response, usage

        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Try each model in sequence until one succeeds.

        Only exceptions raised while *opening* a model's stream are candidates for fallback.
        Once a stream has been yielded, exceptions raised while the caller consumes it
        propagate without trying further models.

        Raises:
            FallbackExceptionGroup: If every model failed with a fallback-worthy exception.
            Exception: Any exception for which the fallback condition is false, unchanged.
        """
        exceptions: list[Exception] = []

        for model in self.models:
            # The exit stack keeps the successful model's stream context open for as long
            # as the caller's `async with` block runs, and closes it afterwards.
            async with AsyncExitStack() as stack:
                try:
                    response = await stack.enter_async_context(
                        model.request_stream(messages, model_settings, model_request_parameters)
                    )
                except Exception as exc:
                    if not self._fallback_on(exc):
                        raise
                    exceptions.append(exc)
                    continue

                self._set_span_attributes(model)
                yield response
                return

        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

    def _set_span_attributes(self, model: Model) -> None:
        """Overwrite the current span's model attributes with those of `model`.

        Only applied when the current span is recording and its `gen_ai.request.model`
        attribute equals this fallback model's combined `model_name`. This is best-effort:
        any exception raised here is suppressed.
        """
        with suppress(Exception):
            span = get_current_span()
            if span.is_recording():
                attributes = getattr(span, 'attributes', {})
                if attributes.get('gen_ai.request.model') == self.model_name:
                    span.set_attributes(InstrumentedModel.model_attributes(model))

    @property
    def model_name(self) -> str:
        """The model name."""
        return f'fallback:{",".join(model.model_name for model in self.models)}'

    @property
    def system(self) -> str:
        """The `system` of every wrapped model, comma-joined and prefixed with `fallback:`."""
        return f'fallback:{",".join(model.system for model in self.models)}'

    @property
    def base_url(self) -> str | None:
        """The base URL of the default (first) model."""
        return self.models[0].base_url
129def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
130 """Create a default fallback condition for the given exceptions."""
132 def fallback_condition(exception: Exception) -> bool:
133 return isinstance(exception, exceptions)
135 return fallback_condition