Coverage for pydantic_evals/pydantic_evals/evaluators/evaluator.py: 99.24%
101 statements
coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
from __future__ import annotations

import inspect
from abc import ABCMeta, abstractmethod
from collections.abc import Awaitable, Mapping
from dataclasses import MISSING, dataclass, fields
from typing import Any, Generic, Union, cast

from pydantic import (
    ConfigDict,
    TypeAdapter,
    ValidationError,
    model_serializer,
)
from pydantic_core import to_jsonable_python
from pydantic_core.core_schema import SerializationInfo
from typing_extensions import TypeVar

from .._utils import get_event_loop
from ._spec import EvaluatorSpec
from .context import EvaluatorContext

EvaluationScalar = Union[bool, int, float, str]
"""The most primitive type allowed as an output from an Evaluator.

`int` and `float` are treated as scores, `str` as labels, and `bool` as assertions.
"""


@dataclass
class EvaluationReason:
    """The result of running an evaluator with an optional explanation.

    Contains a scalar value and an optional "reason" explaining the value.

    Args:
        value: The scalar result of the evaluation (boolean, integer, float, or string).
        reason: An optional explanation of the evaluation result.
    """

    value: EvaluationScalar
    reason: str | None = None


EvaluatorOutput = Union[EvaluationScalar, EvaluationReason, Mapping[str, Union[EvaluationScalar, EvaluationReason]]]
"""Type for the output of an evaluator, which can be a scalar, an EvaluationReason, or a mapping of names to either."""


# TODO(DavidM): Add bound=EvaluationScalar to the following typevar after we upgrade to pydantic 2.11
EvaluationScalarT = TypeVar('EvaluationScalarT', default=EvaluationScalar, covariant=True)
"""Type variable for the scalar result type of an evaluation."""

T = TypeVar('T')


@dataclass
class EvaluationResult(Generic[EvaluationScalarT]):
    """The details of an individual evaluation result.

    Contains the name, value, reason, and source evaluator for a single evaluation.

    Args:
        name: The name of the evaluation.
        value: The scalar result of the evaluation.
        reason: An optional explanation of the evaluation result.
        source: The evaluator that produced this result.
    """

    name: str
    value: EvaluationScalarT
    reason: str | None
    source: Evaluator

    def downcast(self, *value_types: type[T]) -> EvaluationResult[T] | None:
        """Attempt to downcast this result to a more specific type.

        Args:
            *value_types: The types to check the value against.

        Returns:
            A downcast version of this result if the value is an instance of one of the given types,
            otherwise None.
        """
        # Check if value matches any of the target types, handling bool as a special case
        for value_type in value_types:
            if isinstance(self.value, value_type):
                # Only match bool with explicit bool type
                if isinstance(self.value, bool) and value_type is not bool:
                    continue
                return cast(EvaluationResult[T], self)
        return None
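

# Illustrative sketch (not part of the original module): downcast() narrows an
# EvaluationResult to one of the requested value types. Because bool is a
# subclass of int, a boolean value only matches an explicit `bool` target.
# The function name is hypothetical.
def _downcast_example(result: EvaluationResult) -> None:
    score = result.downcast(int, float)  # None if the value is a bool or a str
    assertion = result.downcast(bool)  # None unless the value is a bool
    print(score, assertion)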


# Evaluators are contravariant in all of their parameters.
InputsT = TypeVar('InputsT', default=Any, contravariant=True)
"""Type variable for the inputs type of the task being evaluated."""

OutputT = TypeVar('OutputT', default=Any, contravariant=True)
"""Type variable for the output type of the task being evaluated."""

MetadataT = TypeVar('MetadataT', default=Any, contravariant=True)
"""Type variable for the metadata type of the task being evaluated."""


class _StrictABCMeta(ABCMeta):
    """An ABC-like metaclass that goes further and disallows even defining abstract subclasses."""

    def __new__(mcls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], /, **kwargs: Any):
        result = super().__new__(mcls, name, bases, namespace, **kwargs)
        # Check if this class is a proper subclass of a _StrictABCMeta instance
        is_proper_subclass = any(isinstance(c, _StrictABCMeta) for c in result.__mro__[1:])
        if is_proper_subclass and result.__abstractmethods__:
            abstractmethods = ', '.join([f'{m!r}' for m in result.__abstractmethods__])
            raise TypeError(f'{name} must implement all abstract methods: {abstractmethods}')
        return result
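

# Illustrative sketch (not part of the original module): a base class built on
# _StrictABCMeta rejects *abstract subclasses* at class-definition time, while
# the base itself may still declare abstract methods. All names below are
# hypothetical.
def _strict_abc_sketch() -> None:
    class _Base(metaclass=_StrictABCMeta):
        @abstractmethod
        def run(self) -> None: ...

    try:

        class _Incomplete(_Base):  # does not implement `run`
            pass

    except TypeError as exc:
        print(exc)  # _Incomplete must implement all abstract methods: 'run'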


@dataclass
class Evaluator(Generic[InputsT, OutputT, MetadataT], metaclass=_StrictABCMeta):
    """Base class for all evaluators.

    Evaluators can assess the performance of a task in a variety of ways, as a function of the EvaluatorContext.

    Subclasses must implement the `evaluate` method. Note it can be defined with either `def` or `async def`.

    Example:
    ```python
    @dataclass
    class ExactMatch(Evaluator[Any, Any, Any]):
        def evaluate(self, ctx: EvaluatorContext) -> bool:
            return ctx.actual_output == ctx.expected_output
    ```
    """

    __pydantic_config__ = ConfigDict(arbitrary_types_allowed=True)

    @classmethod
    def name(cls) -> str:
        """Return the 'name' of this Evaluator to use during serialization.

        Returns:
            The name of the Evaluator, which is typically the class name.
        """
        # Note: if we wanted to prefer snake_case, we could use:
        # from pydantic.alias_generators import to_snake
        # return to_snake(cls.__name__)
        return cls.__name__

    @abstractmethod
    def evaluate(
        self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT]
    ) -> EvaluatorOutput | Awaitable[EvaluatorOutput]:  # pragma: no cover
        """Evaluate the task output in the given context.

        This is the main evaluation method that subclasses must implement. It can be either synchronous
        or asynchronous, returning either an EvaluatorOutput directly or an Awaitable[EvaluatorOutput].

        Args:
            ctx: The context containing the inputs, outputs, and metadata for evaluation.

        Returns:
            The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping
            of evaluation names to either of those. Can be returned either synchronously or as an
            awaitable for asynchronous evaluation.
        """
        raise NotImplementedError('You must implement `evaluate`.')

    def evaluate_sync(self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT]) -> EvaluatorOutput:
        """Run the evaluator synchronously, handling both sync and async implementations.

        This method ensures synchronous execution by running any async evaluate implementation
        to completion using run_until_complete.

        Args:
            ctx: The context containing the inputs, outputs, and metadata for evaluation.

        Returns:
            The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping
            of evaluation names to either of those.
        """
        output = self.evaluate(ctx)
        if inspect.iscoroutine(output):  # pragma: no cover
            return get_event_loop().run_until_complete(output)
        else:
            return cast(EvaluatorOutput, output)

    async def evaluate_async(self, ctx: EvaluatorContext[InputsT, OutputT, MetadataT]) -> EvaluatorOutput:
        """Run the evaluator asynchronously, handling both sync and async implementations.

        This method ensures asynchronous execution by properly awaiting any async evaluate
        implementation. For synchronous implementations, it returns the result directly.

        Args:
            ctx: The context containing the inputs, outputs, and metadata for evaluation.

        Returns:
            The evaluation result, which can be a scalar value, an EvaluationReason, or a mapping
            of evaluation names to either of those.
        """
        # Note: If self.evaluate is synchronous, but you need to prevent this from blocking, override this method with:
        # return await anyio.to_thread.run_sync(self.evaluate, ctx)
        output = self.evaluate(ctx)
        if inspect.iscoroutine(output):
            return await output
        else:
            return cast(EvaluatorOutput, output)

    @model_serializer(mode='plain')
    def serialize(self, info: SerializationInfo) -> Any:
        """Serialize this Evaluator to a JSON-serializable form.

        Returns:
            A JSON-serializable representation of this evaluator as an EvaluatorSpec.
        """
        raw_arguments: dict[str, Any] = {}
        for field in fields(self):
            value = getattr(self, field.name)
            # always exclude defaults:
            if field.default is not MISSING:
                if value == field.default:
                    continue
            if field.default_factory is not MISSING:
                if value == field.default_factory():
                    continue
            raw_arguments[field.name] = value

        arguments: None | tuple[Any,] | dict[str, Any]
        if len(raw_arguments) == 0:
            arguments = None
        elif len(raw_arguments) == 1:
            arguments = (next(iter(raw_arguments.values())),)
        else:
            arguments = raw_arguments
        return to_jsonable_python(EvaluatorSpec(name=self.name(), arguments=arguments), context=info.context)
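

# Illustrative sketch (not part of the original module): a minimal concrete
# Evaluator with a configurable field. Its `evaluate` is synchronous, so
# `evaluate_sync` returns the result directly rather than driving an event
# loop, and serialization omits fields that still equal their defaults from
# the resulting EvaluatorSpec arguments. All names below are hypothetical.
@dataclass
class _MinLength(Evaluator[Any, str, Any]):
    min_length: int = 1

    def evaluate(self, ctx: EvaluatorContext[Any, str, Any]) -> EvaluatorOutput:
        return EvaluationReason(
            value=len(ctx.actual_output) >= self.min_length,
            reason=f'expected at least {self.min_length} characters',
        )


def _min_length_example(ctx: EvaluatorContext[Any, str, Any]) -> None:
    evaluator = _MinLength(min_length=10)
    print(evaluator.name())  # '_MinLength'
    print(evaluator.evaluate_sync(ctx))  # EvaluationReason(value=..., reason=...)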


async def run_evaluator(
    evaluator: Evaluator[InputsT, OutputT, MetadataT], ctx: EvaluatorContext[InputsT, OutputT, MetadataT]
) -> list[EvaluationResult]:
    """Run an evaluator and return the results.

    This function runs an evaluator on the given context and processes the results into
    a standardized format.

    Args:
        evaluator: The evaluator to run.
        ctx: The context containing the inputs, outputs, and metadata for evaluation.

    Returns:
        A list of evaluation results.

    Raises:
        ValueError: If the evaluator returns a value of an invalid type.
    """
    raw_results = await evaluator.evaluate_async(ctx)

    try:
        results = _EVALUATOR_OUTPUT_ADAPTER.validate_python(raw_results)
    except ValidationError as e:
        raise ValueError(f'{evaluator!r}.evaluate returned a value of an invalid type: {raw_results!r}.') from e

    results = _convert_to_mapping(results, scalar_name=evaluator.name())

    details: list[EvaluationResult] = []
    for name, result in results.items():
        if not isinstance(result, EvaluationReason):
            result = EvaluationReason(value=result)
        details.append(EvaluationResult(name=name, value=result.value, reason=result.reason, source=evaluator))

    return details
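

# Illustrative sketch (not part of the original module): run_evaluator flattens
# any EvaluatorOutput into a list of EvaluationResult entries; a scalar or a
# single EvaluationReason is named after the evaluator itself, while a mapping
# keeps its own keys. All names below are hypothetical.
async def _run_evaluator_sketch(ctx: EvaluatorContext) -> None:
    @dataclass
    class _OutputMatches(Evaluator):
        def evaluate(self, ctx: EvaluatorContext) -> EvaluatorOutput:
            return EvaluationReason(
                value=ctx.actual_output == ctx.expected_output,
                reason='compared the actual output to the expected output',
            )

    for result in await run_evaluator(_OutputMatches(), ctx):
        print(result.name, result.value, result.reason)  # name is '_OutputMatches'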


_EVALUATOR_OUTPUT_ADAPTER = TypeAdapter[EvaluatorOutput](EvaluatorOutput)


def _convert_to_mapping(
    result: EvaluatorOutput, *, scalar_name: str
) -> Mapping[str, EvaluationScalar | EvaluationReason]:
    """Convert an evaluator output to a mapping from names to scalar values or evaluation reasons.

    Args:
        result: The evaluator output to convert.
        scalar_name: The name to use for a scalar result.

    Returns:
        A mapping from names to scalar values or evaluation reasons.
    """
    if isinstance(result, Mapping):
        return result
    return {scalar_name: result}
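

# Illustrative sketch (not part of the original module): scalar outputs are
# keyed by the supplied name, while mappings pass through unchanged.
assert _convert_to_mapping(0.5, scalar_name='Accuracy') == {'Accuracy': 0.5}
assert _convert_to_mapping({'accuracy': 0.5}, scalar_name='Accuracy') == {'accuracy': 0.5}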