Coverage for pydantic_ai_slim/pydantic_ai/models/fallback.py: 96.34%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-28 17:27 +0000

1from __future__ import annotations as _annotations 

2 

3from collections.abc import AsyncIterator 

4from contextlib import AsyncExitStack, asynccontextmanager, suppress 

5from dataclasses import dataclass, field 

6from typing import TYPE_CHECKING, Callable 

7 

8from opentelemetry.trace import get_current_span 

9 

10from pydantic_ai.models.instrumented import InstrumentedModel 

11 

12from ..exceptions import FallbackExceptionGroup, ModelHTTPError 

13from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model 

14 

15if TYPE_CHECKING: 

16 from ..messages import ModelMessage, ModelResponse 

17 from ..settings import ModelSettings 

18 from ..usage import Usage 

19 

20 

@dataclass(init=False)
class FallbackModel(Model):
    """A model that uses one or more fallback models upon failure.

    Models are tried in the order given to `__init__`. When a model raises an
    exception accepted by the fallback condition, the next model is tried; any
    other exception propagates immediately.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    # Candidate models in the order they are tried; index 0 is the default model.
    models: list[Model]

    # Value of `model_name` captured at construction time; excluded from `repr`.
    _model_name: str = field(repr=False)
    # Predicate deciding whether a raised exception should trigger the next model.
    _fallback_on: Callable[[Exception], bool]

    def __init__(
        self,
        default_model: Model | KnownModelName,
        *fallback_models: Model | KnownModelName,
        fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
    ):
        """Initialize a fallback model instance.

        Args:
            default_model: The name or instance of the default model to use.
            fallback_models: The names or instances of the fallback models to use upon failure.
            fallback_on: A callable or tuple of exceptions that should trigger a fallback.
        """
        self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]

        # A tuple of exception classes is wrapped into an `isinstance` predicate so
        # that `_fallback_on` is always a callable.
        if isinstance(fallback_on, tuple):
            self._fallback_on = _default_fallback_condition_factory(fallback_on)
        else:
            self._fallback_on = fallback_on

        # `_model_name` is a declared dataclass field, so the generated `__eq__`
        # (and helpers such as `dataclasses.asdict`) read it; leaving it unset
        # makes them raise `AttributeError`. The `model_name` property stays the
        # source of truth and is still computed from `self.models` on each access.
        self._model_name = self.model_name

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Try each model in sequence until one succeeds.

        In case of failure, raise a FallbackExceptionGroup with all exceptions.

        Args:
            messages: Passed through unchanged to each model's `request`.
            model_settings: Passed through unchanged to each model's `request`.
            model_request_parameters: Passed through unchanged to each model's `request`.

        Returns:
            The `(response, usage)` pair of the first model that does not raise.

        Raises:
            FallbackExceptionGroup: If every model raised an exception accepted by
                the fallback condition.
            Exception: Any exception rejected by the fallback condition is
                re-raised immediately, without trying the remaining models.
        """
        exceptions: list[Exception] = []

        for model in self.models:
            try:
                response, usage = await model.request(messages, model_settings, model_request_parameters)
            except Exception as exc:
                if self._fallback_on(exc):
                    exceptions.append(exc)
                    continue
                # Not a fallback-eligible error: propagate it with its original traceback.
                raise

            self._set_span_attributes(model)
            return response, usage

        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Try each model in sequence until one succeeds.

        Only exceptions raised while entering a model's stream context are
        considered for fallback; the `yield` sits outside the `try`, so errors
        raised after the stream has been handed to the caller are not caught here.

        Raises:
            FallbackExceptionGroup: If every model raised an exception accepted by
                the fallback condition while opening its stream.
            Exception: Any exception rejected by the fallback condition is
                re-raised immediately.
        """
        exceptions: list[Exception] = []

        for model in self.models:
            # The exit stack keeps the underlying stream context open while the
            # caller consumes the yielded response and closes it afterwards.
            async with AsyncExitStack() as stack:
                try:
                    response = await stack.enter_async_context(
                        model.request_stream(messages, model_settings, model_request_parameters)
                    )
                except Exception as exc:
                    if self._fallback_on(exc):
                        exceptions.append(exc)
                        continue
                    # NOTE(review): the coverage report this file was taken from marks
                    # this re-raise path as never exercised (condition always true).
                    raise

                self._set_span_attributes(model)
                yield response
                return

        raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

    def _set_span_attributes(self, model: Model):
        """Best-effort update of the current span with the attributes of `model`.

        Any exception raised while doing so is suppressed, so a tracing problem
        never fails a request.
        """
        with suppress(Exception):
            span = get_current_span()
            if span.is_recording():
                # The span object may not expose `attributes`; fall back to an empty mapping.
                attributes = getattr(span, 'attributes', {})
                # Only touch spans whose recorded request model is this fallback model.
                # NOTE(review): the coverage report marks the non-matching case as
                # never exercised.
                if attributes.get('gen_ai.request.model') == self.model_name:
                    span.set_attributes(InstrumentedModel.model_attributes(model))

    @property
    def model_name(self) -> str:
        """The model name."""
        return f'fallback:{",".join(model.model_name for model in self.models)}'

    @property
    def system(self) -> str:
        """The `system` values of all wrapped models, comma-joined and prefixed with `fallback:`."""
        return f'fallback:{",".join(model.system for model in self.models)}'

    @property
    def base_url(self) -> str | None:
        """The `base_url` of the first (default) model."""
        return self.models[0].base_url

127 

128 

def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:
    """Build the predicate used when `fallback_on` is given as a tuple of exception classes.

    Args:
        exceptions: Exception classes that should trigger a fallback.

    Returns:
        A callable that takes an exception and returns `True` when it is an
        instance of any class in `exceptions`, `False` otherwise.
    """

    def fallback_condition(exception: Exception) -> bool:
        # `isinstance` accepts a tuple of classes; an empty tuple never matches.
        is_match = isinstance(exception, exceptions)
        return is_match

    return fallback_condition