Coverage for pydantic_evals/pydantic_evals/dataset.py: 97.76%
382 statements
coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1"""Dataset management for pydantic evals.
3This module provides functionality for creating, loading, saving, and evaluating datasets of test cases.
4Each case must have inputs, and can optionally have a name, expected output, metadata, and case-specific evaluators.
6Datasets can be loaded from and saved to YAML or JSON files, and can be evaluated against
7a task function to produce an evaluation report.
8"""
10from __future__ import annotations as _annotations
12import functools
13import inspect
14import sys
15import time
16import warnings
17from collections.abc import Awaitable, Mapping, Sequence
18from contextlib import AsyncExitStack
19from contextvars import ContextVar
20from dataclasses import dataclass, field
21from pathlib import Path
22from typing import Any, Callable, Generic, Literal, Union, cast
24import anyio
25import logfire_api
26import yaml
27from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationError, model_serializer
28from pydantic._internal import _typing_extra
29from pydantic_core import to_json, to_jsonable_python
30from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler
31from typing_extensions import NotRequired, Self, TypedDict, TypeVar
33from pydantic_evals._utils import get_event_loop
35from ._utils import get_unwrapped_function_name, task_group_gather
36from .evaluators import EvaluationResult, Evaluator, run_evaluator
37from .evaluators._spec import EvaluatorSpec
38from .evaluators.common import DEFAULT_EVALUATORS
39from .evaluators.context import EvaluatorContext
40from .otel import SpanTree
41from .otel._context_in_memory_span_exporter import context_subtree
42from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate
44if sys.version_info < (3, 11): # pragma: no cover
45 from exceptiongroup import ExceptionGroup
46else:
47 ExceptionGroup = ExceptionGroup
49# while waiting for https://github.com/pydantic/logfire/issues/745
50try:
51 import logfire._internal.stack_info
52except ImportError: # pragma: no cover
53 pass
54else:
55 from pathlib import Path
57 logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
59_logfire = logfire_api.Logfire(otel_scope='pydantic-evals')
61InputsT = TypeVar('InputsT', default=Any)
62"""Generic type for the inputs to the task being evaluated."""
63OutputT = TypeVar('OutputT', default=Any)
64"""Generic type for the expected output of the task being evaluated."""
65MetadataT = TypeVar('MetadataT', default=Any)
66"""Generic type for the metadata associated with the task being evaluated."""
68DEFAULT_DATASET_PATH = './test_cases.yaml'
69"""Default path for saving/loading datasets."""
70DEFAULT_SCHEMA_PATH_TEMPLATE = './{stem}_schema.json'
71"""Default template for schema file paths, where {stem} is replaced with the dataset filename stem."""
72_YAML_SCHEMA_LINE_PREFIX = '# yaml-language-server: $schema='
75class _CaseModel(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid'):
76 """Internal model for a case, used for serialization/deserialization."""
78 name: str | None = None
79 inputs: InputsT
80 metadata: MetadataT | None = None
81 expected_output: OutputT | None = None
82 evaluators: list[EvaluatorSpec] = Field(default_factory=list)
85class _DatasetModel(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid'):
86 """Internal model for a dataset, used for serialization/deserialization."""
88 # $schema is included to avoid validation failures from the `$schema` key; see `_add_json_schema` below for context
89 json_schema_path: str | None = Field(default=None, alias='$schema')
90 cases: list[_CaseModel[InputsT, OutputT, MetadataT]]
91 evaluators: list[EvaluatorSpec] = Field(default_factory=list)
94@dataclass(init=False)
95class Case(Generic[InputsT, OutputT, MetadataT]):
96 """A single row of a [`Dataset`][pydantic_evals.Dataset].
98 Each case represents a single test scenario with inputs to test. A case may optionally specify a name, expected
99 outputs to compare against, and arbitrary metadata.
101 Cases can also have their own specific evaluators which are run in addition to dataset-level evaluators.
103 Example:
104 ```python
105 case = Case(
106 name="Simple addition",
107 inputs={"a": 1, "b": 2},
108 expected_output=3,
109 metadata={"description": "Tests basic addition"}
110 )
111 ```
112 """
114 name: str | None
115 """Name of the case. This is used to identify the case in the report and can be used to filter cases."""
116 inputs: InputsT
117 """Inputs to the task. This is the input to the task that will be evaluated."""
118 metadata: MetadataT | None
119 """Metadata to be used in the evaluation.
121 This can be used to provide additional information about the case to the evaluators.
122 """
123 expected_output: OutputT | None
124 """Expected output of the task. This is the expected output of the task that will be evaluated."""
125 evaluators: list[Evaluator[InputsT, OutputT, MetadataT]]
126 """Evaluators to be used just on this case."""
128 def __init__(
129 self,
130 *,
131 name: str | None = None,
132 inputs: InputsT,
133 metadata: MetadataT | None = None,
134 expected_output: OutputT | None = None,
135 evaluators: tuple[Evaluator[InputsT, OutputT, MetadataT], ...] = (),
136 ):
137 """Initialize a new test case.
139 Args:
140 name: Optional name for the case. If not provided, a generic name will be assigned when added to a dataset.
141 inputs: The inputs to the task being evaluated.
142 metadata: Optional metadata for the case, which can be used by evaluators.
143 expected_output: Optional expected output of the task, used for comparison in evaluators.
144 evaluators: Tuple of evaluators specific to this case. These are in addition to any
145 dataset-level evaluators.
147 """
148 # Note: `evaluators` must be a tuple instead of Sequence due to misbehavior with pyright's generic parameter
149 # inference if it has type `Sequence`
150 self.name = name
151 self.inputs = inputs
152 self.metadata = metadata
153 self.expected_output = expected_output
154 self.evaluators = list(evaluators)
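    # Example sketch: a case that carries its own evaluator in addition to any dataset-level
    # ones. Assumes `IsInstance` is available from `pydantic_evals.evaluators`.
    #
    #     from pydantic_evals.evaluators import IsInstance
    #
    #     case = Case(
    #         name='capitalize hello',
    #         inputs={'text': 'hello'},
    #         expected_output='Hello',
    #         evaluators=(IsInstance(type_name='str'),),  # a tuple, per the signature above
    #     )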
157# TODO: Consider making one or more of the following changes to this type:
158 # * Add `task: Callable[[InputsT], Awaitable[OutputT]]` as a field
159# * Add `inputs_type`, `output_type`, etc. as kwargs on `__init__`
160# * Rename to `Evaluation`
161# TODO: Allow `task` to be sync _or_ async
162class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', arbitrary_types_allowed=True):
163 """A dataset of test [cases][pydantic_evals.Case].
165 Datasets allow you to organize a collection of test cases and evaluate them against a task function.
166 They can be loaded from and saved to YAML or JSON files, and can have dataset-level evaluators that
167 apply to all cases.
169 Example:
170 ```python
171 # Create a dataset with two test cases
172 dataset = Dataset(
173 cases=[
174 Case(name="test1", inputs={"text": "Hello"}, expected_output="HELLO"),
175 Case(name="test2", inputs={"text": "World"}, expected_output="WORLD"),
176 ],
177 evaluators=[ExactMatch()]
178 )
180 # Evaluate the dataset against a task function
181 async def uppercase(inputs: dict) -> str:
182 return inputs["text"].upper()
184 report = await dataset.evaluate(uppercase)
185 ```
186 """
188 cases: list[Case[InputsT, OutputT, MetadataT]]
189 """List of test cases in the dataset."""
190 evaluators: list[Evaluator[InputsT, OutputT, MetadataT]] = []
191 """List of evaluators to be used on all cases in the dataset."""
193 def __init__(
194 self,
195 *,
196 cases: Sequence[Case[InputsT, OutputT, MetadataT]],
197 evaluators: Sequence[Evaluator[InputsT, OutputT, MetadataT]] = (),
198 ):
199 """Initialize a new dataset with test cases and optional evaluators.
201 Args:
202 cases: Sequence of test cases to include in the dataset.
203 evaluators: Optional sequence of evaluators to apply to all cases in the dataset.
204 """
205 case_names = set[str]()
206 for case in cases:
207 if case.name is None:
208 continue
209 if case.name in case_names:
210 raise ValueError(f'Duplicate case name: {case.name!r}')
211 case_names.add(case.name)
213 super().__init__(
214 cases=cases,
215 evaluators=list(evaluators),
216 )
218 async def evaluate(
219 self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
220 ) -> EvaluationReport:
221 """Evaluates the test cases in the dataset using the given task.
223 This method runs the task on each case in the dataset, applies evaluators,
224 and collects results into a report. Cases are run concurrently, limited by `max_concurrency` if specified.
226 Args:
227 task: The task to evaluate. This should be a callable that takes the inputs of the case
228 and returns the output.
229 name: The name of the task being evaluated; this is used to identify the task in the report.
230 If omitted, the name of the task function will be used.
231 max_concurrency: The maximum number of concurrent evaluations of the task to allow.
232 If None, all cases will be evaluated concurrently.
234 Returns:
235 A report containing the results of the evaluation.
236 """
237 name = name or get_unwrapped_function_name(task)
239 limiter = anyio.Semaphore(max_concurrency) if max_concurrency is not None else AsyncExitStack()
240 with _logfire.span('evaluate {name}', name=name) as eval_span:
242 async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
243 async with limiter:
244 return await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
246 report = EvaluationReport(
247 name=name,
248 cases=await task_group_gather(
249 [
250 lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
251 for i, case in enumerate(self.cases, 1)
252 ]
253 ),
254 )
255 # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
256 eval_span.set_attribute('cases', report.cases)
257 # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
258 eval_span.set_attribute('averages', ReportCaseAggregate.average(report.cases))
260 return report
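    # Example sketch: limiting concurrency and overriding the task name in the report.
    # Assumes `dataset` and the async `uppercase` task from the class docstring above.
    #
    #     report = await dataset.evaluate(uppercase, name='uppercase-v1', max_concurrency=4)
    #     print(report)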
262 def evaluate_sync(
263 self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
264 ) -> EvaluationReport: # pragma: no cover
265 """Evaluates the test cases in the dataset using the given task.
267 This is a synchronous wrapper around [`evaluate`][pydantic_evals.Dataset.evaluate] provided for convenience.
269 Args:
270 task: The task to evaluate. This should be a callable that takes the inputs of the case
271 and returns the output.
272 name: The name of the task being evaluated; this is used to identify the task in the report.
273 If omitted, the name of the task function will be used.
274 max_concurrency: The maximum number of concurrent evaluations of the task to allow.
275 If None, all cases will be evaluated concurrently.
277 Returns:
278 A report containing the results of the evaluation.
279 """
280 return get_event_loop().run_until_complete(self.evaluate(task, name=name, max_concurrency=max_concurrency))
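    # Example sketch: the synchronous wrapper drives the event loop itself, so it can be called
    # from ordinary scripts or tests without `await`.
    #
    #     async def uppercase(inputs: dict) -> str:
    #         return inputs['text'].upper()
    #
    #     report = dataset.evaluate_sync(uppercase)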
282 def add_case(
283 self,
284 *,
285 name: str | None = None,
286 inputs: InputsT,
287 metadata: MetadataT | None = None,
288 expected_output: OutputT | None = None,
289 evaluators: tuple[Evaluator[InputsT, OutputT, MetadataT], ...] = (),
290 ) -> None:
291 """Adds a case to the dataset.
293 This is a convenience method for creating a [`Case`][pydantic_evals.Case] and adding it to the dataset.
295 Args:
296 name: Optional name for the case. If not provided, a generic name will be assigned.
297 inputs: The inputs to the task being evaluated.
298 metadata: Optional metadata for the case, which can be used by evaluators.
299 expected_output: The expected output of the task, used for comparison in evaluators.
300 evaluators: Tuple of evaluators specific to this case, in addition to dataset-level evaluators.
301 """
302 if name in {case.name for case in self.cases}:
303 raise ValueError(f'Duplicate case name: {name!r}')
305 case = Case[InputsT, OutputT, MetadataT](
306 name=name,
307 inputs=inputs,
308 metadata=metadata,
309 expected_output=expected_output,
310 evaluators=evaluators,
311 )
312 self.cases.append(case)
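    # Example sketch: building a dataset incrementally. Re-using an existing case name raises
    # ValueError, mirroring the duplicate-name check in `__init__`.
    #
    #     dataset.add_case(name='test3', inputs={'text': 'Mixed'}, expected_output='MIXED')
    #     dataset.add_case(name='test3', inputs={'text': 'Again'})  # raises ValueError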
314 def add_evaluator(
315 self,
316 evaluator: Evaluator[InputsT, OutputT, MetadataT],
317 specific_case: str | None = None,
318 ) -> None:
319 """Adds an evaluator to the dataset or a specific case.
321 Args:
322 evaluator: The evaluator to add.
323 specific_case: If provided, the evaluator will only be added to the case with this name.
324 If None, the evaluator will be added to all cases in the dataset.
326 Raises:
327 ValueError: If `specific_case` is provided but no case with that name exists in the dataset.
328 """
329 if specific_case is None:
330 self.evaluators.append(evaluator)
331 else:
332 # If this is too slow, we could try to add a case lookup dict.
333 # Note that if we do that, we'd need to make the cases list private to prevent modification.
334 added = False
335 for case in self.cases:
336 if case.name == specific_case:
337 case.evaluators.append(evaluator)
338 added = True
339 if not added:
340 raise ValueError(f'Case {specific_case!r} not found in the dataset')
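    # Example sketch: attaching evaluators after construction. Assumes `EqualsExpected` is one of
    # the default evaluators and takes no required arguments.
    #
    #     from pydantic_evals.evaluators import EqualsExpected
    #
    #     dataset.add_evaluator(EqualsExpected())                         # applies to every case
    #     dataset.add_evaluator(EqualsExpected(), specific_case='test1')  # only the case named 'test1'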
342 @classmethod
343 @functools.cache
344 def _params(cls) -> tuple[type[InputsT], type[OutputT], type[MetadataT]]:
345 """Get the type parameters for the Dataset class.
347 Returns:
348 A tuple of (InputsT, OutputT, MetadataT) types.
349 """
350 for c in cls.__mro__:
351 metadata = getattr(c, '__pydantic_generic_metadata__', {})
352 if len(args := (metadata.get('args', ()) or getattr(c, '__args__', ()))) == 3:  # coverage: 352 ↛ 350; line 352 didn't jump to line 350 because the condition on line 352 was always true
353 return args
354 else: # pragma: no cover
355 warnings.warn(
356 f'Could not determine the generic parameters for {cls}; using `Any` for each. '
357 f'You should explicitly set the generic parameters via `Dataset[MyInputs, MyOutput, MyMetadata]`'
358 f'when serializing or deserializing.',
359 UserWarning,
360 )
361 return Any, Any, Any # type: ignore
363 @classmethod
364 def from_file(
365 cls,
366 path: Path | str,
367 fmt: Literal['yaml', 'json'] | None = None,
368 custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
369 ) -> Self:
370 """Load a dataset from a file.
372 Args:
373 path: Path to the file to load.
374 fmt: Format of the file. If None, the format will be inferred from the file extension.
375 Must be either 'yaml' or 'json'.
376 custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
377 These are additional evaluators beyond the default ones.
379 Returns:
380 A new Dataset instance loaded from the file.
382 Raises:
383 ValidationError: If the file cannot be parsed as a valid dataset.
384 ValueError: If the format cannot be inferred from the file extension.
385 """
386 path = Path(path)
387 fmt = cls._infer_fmt(path, fmt)
389 raw = Path(path).read_text()
390 try:
391 return cls.from_text(raw, fmt=fmt, custom_evaluator_types=custom_evaluator_types)
392 except ValidationError as e: # pragma: no cover
393 raise ValueError(f'{path} contains data that does not match the schema for {cls.__name__}:\n{e}.') from e
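    # Example sketch: loading a YAML dataset. Parameterizing the generic types is recommended so
    # cases validate against your own input/output/metadata models; `MyEvaluator` here stands in
    # for any custom `@dataclass` Evaluator subclass you want to be deserializable.
    #
    #     dataset = Dataset[dict, str, Any].from_file(
    #         'test_cases.yaml',
    #         custom_evaluator_types=[MyEvaluator],
    #     )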
395 @classmethod
396 def from_text(
397 cls,
398 contents: str,
399 fmt: Literal['yaml', 'json'] = 'yaml',
400 custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
401 ) -> Self:
402 """Load a dataset from a string.
404 Args:
405 contents: The string content to parse.
406 fmt: Format of the content. Must be either 'yaml' or 'json'.
407 custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
408 These are additional evaluators beyond the default ones.
410 Returns:
411 A new Dataset instance parsed from the string.
413 Raises:
414 ValidationError: If the content cannot be parsed as a valid dataset.
415 """
416 if fmt == 'yaml':
417 loaded = yaml.safe_load(contents)
418 return cls.from_dict(loaded, custom_evaluator_types)
419 else:
420 dataset_model_type = cls._serialization_type()
421 dataset_model = dataset_model_type.model_validate_json(contents)
422 return cls._from_dataset_model(dataset_model, custom_evaluator_types)
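    # Example sketch of the YAML shape this parses. Evaluator specs may be a bare name, a
    # single-parameter short form, or a mapping of parameters (see
    # `model_json_schema_with_evaluators` below); `EqualsExpected` is assumed to be a registered
    # default evaluator with no required parameters.
    #
    #     yaml_text = '''
    #     cases:
    #       - name: test1
    #         inputs: {text: Hello}
    #         expected_output: HELLO
    #     evaluators:
    #       - EqualsExpected
    #     '''
    #     dataset = Dataset[dict, str, Any].from_text(yaml_text, fmt='yaml')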
424 @classmethod
425 def from_dict(
426 cls,
427 data: dict[str, Any],
428 custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
429 ) -> Self:
430 """Load a dataset from a dictionary.
432 Args:
433 data: Dictionary representation of the dataset.
434 custom_evaluator_types: Custom evaluator classes to use when deserializing the dataset.
435 These are additional evaluators beyond the default ones.
437 Returns:
438 A new Dataset instance created from the dictionary.
440 Raises:
441 ValidationError: If the dictionary cannot be converted to a valid dataset.
442 """
443 dataset_model_type = cls._serialization_type()
444 dataset_model = dataset_model_type.model_validate(data)
445 return cls._from_dataset_model(dataset_model, custom_evaluator_types)
447 @classmethod
448 def _from_dataset_model(
449 cls,
450 dataset_model: _DatasetModel[InputsT, OutputT, MetadataT],
451 custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
452 ) -> Self:
453 """Create a Dataset from a _DatasetModel.
455 Args:
456 dataset_model: The _DatasetModel to convert.
457 custom_evaluator_types: Custom evaluator classes to register for deserialization.
459 Returns:
460 A new Dataset instance created from the _DatasetModel.
461 """
462 registry = _get_registry(custom_evaluator_types)
464 cases: list[Case[InputsT, OutputT, MetadataT]] = []
465 errors: list[ValueError] = []
466 dataset_evaluators: list[Evaluator[Any, Any, Any]] = []
467 for spec in dataset_model.evaluators:
468 try:
469 dataset_evaluator = _load_evaluator_from_registry(registry, None, spec)
470 except ValueError as e:
471 errors.append(e)
472 continue
473 dataset_evaluators.append(dataset_evaluator)
475 for row in dataset_model.cases:
476 evaluators: list[Evaluator[Any, Any, Any]] = []
477 for spec in row.evaluators:
478 try:
479 evaluator = _load_evaluator_from_registry(registry, row.name, spec)
480 except ValueError as e:
481 errors.append(e)
482 continue
483 evaluators.append(evaluator)
484 row = Case[InputsT, OutputT, MetadataT](
485 name=row.name,
486 inputs=row.inputs,
487 metadata=row.metadata,
488 expected_output=row.expected_output,
489 )
490 row.evaluators = evaluators
491 cases.append(row)
492 if errors:
493 raise ExceptionGroup(f'{len(errors)} error(s) loading evaluators from registry', errors[:3])
494 result = cls(cases=cases)
495 result.evaluators = dataset_evaluators
496 return result
498 def to_file(
499 self,
500 path: Path | str,
501 fmt: Literal['yaml', 'json'] | None = None,
502 schema_path: Path | str | None = DEFAULT_SCHEMA_PATH_TEMPLATE,
503 custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
504 ):
505 """Save the dataset to a file.
507 Args:
508 path: Path to save the dataset to.
509 fmt: Format to use. If None, the format will be inferred from the file extension.
510 Must be either 'yaml' or 'json'.
511 schema_path: Path to save the JSON schema to. If None, no schema will be saved.
512 Can be a string template with {stem} which will be replaced with the dataset filename stem.
513 custom_evaluator_types: Custom evaluator classes to include in the schema.
514 """
515 path = Path(path)
516 fmt = self._infer_fmt(path, fmt)
518 schema_ref: str | None = None
519 if schema_path is not None:  # coverage: 519 ↛ 532; line 519 didn't jump to line 532 because the condition on line 519 was always true
520 if isinstance(schema_path, str):  # coverage: 520 ↛ 523; line 520 didn't jump to line 523 because the condition on line 520 was always true
521 schema_path = Path(schema_path.format(stem=path.stem))
523 if not schema_path.is_absolute():
524 schema_ref = str(schema_path)
525 schema_path = path.parent / schema_path
526 elif schema_path.is_relative_to(path): # pragma: no cover
527 schema_ref = str(_get_relative_path_reference(schema_path, path))
528 else: # pragma: no cover
529 schema_ref = str(schema_path)
530 self._save_schema(schema_path, custom_evaluator_types)
532 context: dict[str, Any] = {'use_short_form': True}
533 if fmt == 'yaml':
534 dumped_data = self.model_dump(mode='json', by_alias=True, exclude_defaults=True, context=context)
535 content = yaml.dump(dumped_data, sort_keys=False)
536 if schema_ref:  # coverage: 536 ↛ 539; line 536 didn't jump to line 539 because the condition on line 536 was always true
537 yaml_language_server_line = f'{_YAML_SCHEMA_LINE_PREFIX}{schema_ref}'
538 content = f'{yaml_language_server_line}\n{content}'
539 path.write_text(content)
540 else:
541 context['$schema'] = schema_ref
542 json_data = self.model_dump_json(indent=2, by_alias=True, exclude_defaults=True, context=context)
543 path.write_text(json_data + '\n')
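    # Example sketch: saving alongside a JSON schema. With the default template, 'cases.yaml'
    # gets a sibling 'cases_schema.json', and the YAML output starts with a
    # '# yaml-language-server: $schema=...' line pointing at it.
    #
    #     dataset.to_file('cases.yaml')                    # schema written to ./cases_schema.json
    #     dataset.to_file('cases.json', schema_path=None)  # JSON output, no schema file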
545 @classmethod
546 def model_json_schema_with_evaluators(
547 cls,
548 custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = (),
549 ) -> dict[str, Any]:
550 """Generate a JSON schema for this dataset type, including evaluator details.
552 This is useful for generating a schema that can be used to validate YAML-format dataset files.
554 Args:
555 custom_evaluator_types: Custom evaluator classes to include in the schema.
557 Returns:
558 A dictionary representing the JSON schema.
559 """
560 # Note: this function could maybe be simplified now that Evaluators are always dataclasses
561 registry = _get_registry(custom_evaluator_types)
563 evaluator_schema_types: list[Any] = []
564 for name, evaluator_class in registry.items():
565 type_hints = _typing_extra.get_function_type_hints(evaluator_class)
566 type_hints.pop('return', None)
567 required_type_hints: dict[str, Any] = {}
569 for p in inspect.signature(evaluator_class).parameters.values():
570 type_hints.setdefault(p.name, Any)
571 if p.default is not p.empty:
572 type_hints[p.name] = NotRequired[type_hints[p.name]]
573 else:
574 required_type_hints[p.name] = type_hints[p.name]
576 def _make_typed_dict(cls_name_prefix: str, fields: dict[str, Any]) -> Any:
577 td = TypedDict(f'{cls_name_prefix}_{name}', fields) # pyright: ignore[reportArgumentType]
578 config = ConfigDict(extra='forbid', arbitrary_types_allowed=True)
579 # TODO: Replace with pydantic.with_config after pydantic 2.11 is released
580 td.__pydantic_config__ = config # pyright: ignore[reportAttributeAccessIssue]
581 return td
583 # Shortest form: just the call name
584 if len(type_hints) == 0 or not required_type_hints:
585 evaluator_schema_types.append(Literal[name])
587 # Short form: can be called with only one parameter
588 if len(type_hints) == 1:
589 [type_hint_type] = type_hints.values()
590 evaluator_schema_types.append(_make_typed_dict('short_evaluator', {name: type_hint_type}))
591 elif len(required_type_hints) == 1:
592 [type_hint_type] = required_type_hints.values()
593 evaluator_schema_types.append(_make_typed_dict('short_evaluator', {name: type_hint_type}))
595 # Long form: multiple parameters, possibly required
596 if len(type_hints) > 1:
597 params_td = _make_typed_dict('evaluator_params', type_hints)
598 evaluator_schema_types.append(_make_typed_dict('evaluator', {name: params_td}))
600 in_type, out_type, meta_type = cls._params()
602 class ClsDatasetRow(BaseModel, extra='forbid'):
603 name: str
604 inputs: in_type
605 metadata: meta_type
606 expected_output: out_type | None = None
607 if evaluator_schema_types:  # coverage: 607 ↛ exit; line 607 didn't exit class 'ClsDatasetRow' because the condition on line 607 was always true
608 evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007
610 ClsDatasetRow.__name__ = cls.__name__ + 'Row'
612 class ClsDataset(BaseModel, extra='forbid'):
613 cases: list[ClsDatasetRow]
614 if evaluator_schema_types:  # coverage: 614 ↛ exit; line 614 didn't exit class 'ClsDataset' because the condition on line 614 was always true
615 evaluators: list[Union[tuple(evaluator_schema_types)]] = [] # pyright: ignore # noqa UP007
617 ClsDataset.__name__ = cls.__name__
619 json_schema = ClsDataset.model_json_schema()
620 # See `_add_json_schema` below, since `$schema` is added to the JSON, it has to be supported in the JSON
621 json_schema['properties']['$schema'] = {'type': 'string'}
622 return json_schema
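    # Example sketch: writing the schema out for editor support (essentially what `_save_schema`
    # below does).
    #
    #     schema = Dataset[dict, str, Any].model_json_schema_with_evaluators()
    #     Path('test_cases_schema.json').write_text(to_json(schema, indent=2).decode() + '\n')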
624 @classmethod
625 def _save_schema(
626 cls, path: Path | str, custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]] = ()
627 ):
628 """Save the JSON schema for this dataset type to a file.
630 Args:
631 path: Path to save the schema to.
632 custom_evaluator_types: Custom evaluator classes to include in the schema.
633 """
634 path = Path(path)
635 json_schema = cls.model_json_schema_with_evaluators(custom_evaluator_types)
636 schema_content = to_json(json_schema, indent=2).decode() + '\n'
637 if not path.exists() or path.read_text() != schema_content:  # coverage: 637 ↛ exit; line 637 didn't return from function '_save_schema' because the condition on line 637 was always true
638 path.write_text(schema_content)
640 @classmethod
641 @functools.cache
642 def _serialization_type(cls) -> type[_DatasetModel[InputsT, OutputT, MetadataT]]:
643 """Get the serialization type for this dataset class.
645 Returns:
646 A _DatasetModel type with the same generic parameters as this Dataset class.
647 """
648 input_type, output_type, metadata_type = cls._params()
649 return _DatasetModel[input_type, output_type, metadata_type]
651 @classmethod
652 def _infer_fmt(cls, path: Path, fmt: Literal['yaml', 'json'] | None) -> Literal['yaml', 'json']:
653 """Infer the format to use for a file based on its extension.
655 Args:
656 path: The path to infer the format for.
657 fmt: The explicitly provided format, if any.
659 Returns:
660 The inferred format ('yaml' or 'json').
662 Raises:
663 ValueError: If the format cannot be inferred from the file extension.
664 """
665 if fmt is not None:
666 return fmt
667 suffix = path.suffix.lower()
668 if suffix in {'.yaml', '.yml'}:
669 return 'yaml'
670 elif suffix == '.json':
671 return 'json'
672 raise ValueError(
673 f'Could not infer format for filename {path.name!r}. Use the `fmt` argument to specify the format.'
674 )
676 @model_serializer(mode='wrap')
677 def _add_json_schema(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo) -> dict[str, Any]:
678 """Add the JSON schema path to the serialized output.
680 See <https://github.com/json-schema-org/json-schema-spec/issues/828> for context, that seems to be the nearest
681 there is to a spec for this.
682 """
683 context = cast(Union[dict[str, Any], None], info.context)
684 if isinstance(context, dict) and (schema := context.get('$schema')):
685 return {'$schema': schema} | nxt(self)
686 else:
687 return nxt(self)
690def _get_relative_path_reference(target: Path, source: Path, _prefix: str = '') -> Path: # pragma: no cover
691 """Get a relative path reference from source to target.
693 Recursively resolve a relative path to target from source, adding '..' as needed.
694 This is useful for creating a relative path reference from a source file to a target file.
696 Args:
697 target: The target path to reference.
698 source: The source path to reference from.
699 _prefix: Internal prefix used during recursion.
701 Returns:
702 A Path object representing the relative path from source to target.
704 Example:
705 If source is '/a/b/c.py' and target is '/a/d/e.py', the relative path reference
706 would be '../../d/e.py'.
707 """
708 # Recursively resolve a relative path to target from source, adding '..' as needed.
709 # This is useful for creating a relative path reference from a source file to a target file.
710 # For example, if source is '/a/b/c.py' and target is '/a/d/e.py', the relative path reference
711 # would be '../../d/e.py'.
712 if not target.is_absolute():
713 target = target.resolve()
714 try:
715 return Path(f'{_prefix}{Path(target).relative_to(source)}')
716 except ValueError:
717 return _get_relative_path_reference(target, source.parent, _prefix=f'{_prefix}../')
720@dataclass
721class _TaskRun:
722 """Internal class to track metrics and attributes for a task run."""
724 attributes: dict[str, Any] = field(init=False, default_factory=dict)
725 metrics: dict[str, int | float] = field(init=False, default_factory=dict)
727 def record_metric(self, name: str, value: int | float) -> None:
728 """Record a metric value.
730 Args:
731 name: The name of the metric.
732 value: The value of the metric.
733 """
734 self.metrics[name] = value
736 def increment_metric(self, name: str, amount: int | float) -> None:
737 """Increment a metric value.
739 Args:
740 name: The name of the metric.
741 amount: The amount to increment by.
743 Note:
744 If the current value is 0 and the increment amount is 0, no metric will be recorded.
745 """
746 current_value = self.metrics.get(name, 0)
747 incremented_value = current_value + amount
748 if current_value == 0 and incremented_value == 0:
749 return # Avoid recording a metric that is always zero
750 self.record_metric(name, incremented_value)
752 def record_attribute(self, name: str, value: Any) -> None:
753 """Record an attribute value.
755 Args:
756 name: The name of the attribute.
757 value: The value of the attribute.
758 """
759 self.attributes[name] = value
762async def _run_task(
763 task: Callable[[InputsT], Awaitable[OutputT]], case: Case[InputsT, OutputT, MetadataT]
764) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
765 """Run a task on a case and return the context for evaluators.
767 Args:
768 task: The task to run.
769 case: The case to run the task on.
771 Returns:
772 An EvaluatorContext containing the inputs, actual output, expected output, and metadata.
774 Raises:
775 Exception: Any exception raised by the task.
776 """
777 task_run = _TaskRun()
778 if _CURRENT_TASK_RUN.get() is not None: # pragma: no cover
779 raise RuntimeError('A task run has already been entered. Task runs should not be nested')
781 # Note: the current behavior is for task execution errors to just bubble up all the way and kill the evaluation.
782 # Should we handle them for the user in some way? If so, I guess we'd want to do that here.
783 token = _CURRENT_TASK_RUN.set(task_run)
784 try:
785 with _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span:
786 with context_subtree() as span_tree:
787 t0 = time.time()
788 task_output = await task(case.inputs)
789 fallback_duration = time.time() - t0
790 finally:
791 _CURRENT_TASK_RUN.reset(token)
793 if isinstance(span_tree, SpanTree):  # coverage: 793 ↛ 811; line 793 didn't jump to line 811 because the condition on line 793 was always true
794 # TODO: Question: Should we make this metric-attributes functionality more user-configurable in some way before merging?
795 # Note: the use of otel for collecting these metrics is the main reason why I think we should require at least otel as a dependency, if not logfire;
796 # otherwise, we don't have a great way to get usage data from arbitrary frameworks.
797 # Ideally we wouldn't need to hard-code the specific logic here, but I'm not sure a great way to expose it to
798 # users. Maybe via an argument of type Callable[[SpanTree], dict[str, int | float]] or similar?
799 for node in span_tree.flattened():
800 if node.attributes.get('gen_ai.operation.name') == 'chat':
801 task_run.increment_metric('requests', 1)
802 for k, v in node.attributes.items():
803 if not isinstance(v, (int, float)):
804 continue
805 # TODO: Revisit this choice to strip the prefix..
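# e.g. 'gen_ai.usage.details.reasoning_tokens' -> 'reasoning_tokens',
#      'gen_ai.usage.input_tokens' -> 'input_tokens'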
806 if k.startswith('gen_ai.usage.details.'):
807 task_run.increment_metric(k[21:], v)
808 elif k.startswith('gen_ai.usage.'):
809 task_run.increment_metric(k[13:], v)
811 return EvaluatorContext[InputsT, OutputT, MetadataT](
812 name=case.name,
813 inputs=case.inputs,
814 metadata=case.metadata,
815 expected_output=case.expected_output,
816 output=task_output,
817 duration=_get_span_duration(task_span, fallback_duration),
818 _span_tree=span_tree,
819 attributes=task_run.attributes,
820 metrics=task_run.metrics,
821 )
824async def _run_task_and_evaluators(
825 task: Callable[[InputsT], Awaitable[OutputT]],
826 case: Case[InputsT, OutputT, MetadataT],
827 report_case_name: str,
828 dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
829) -> ReportCase:
830 """Run a task on a case and evaluate the results.
832 Args:
833 task: The task to run.
834 case: The case to run the task on.
835 report_case_name: The name to use for this case in the report.
836 dataset_evaluators: Evaluators from the dataset to apply to this case.
838 Returns:
839 A ReportCase containing the evaluation results.
840 """
841 with _logfire.span(
842 '{task_name}: {case_name}',
843 task_name=get_unwrapped_function_name(task),
844 case_name=case.name,
845 inputs=case.inputs,
846 metadata=case.metadata,
847 ) as case_span:
848 t0 = time.time()
849 scoring_context = await _run_task(task, case)
851 case_span.set_attribute('output', scoring_context.output)
852 case_span.set_attribute('task_duration', scoring_context.duration)
853 case_span.set_attribute('metrics', scoring_context.metrics)
854 case_span.set_attribute('attributes', scoring_context.attributes)
856 evaluators = case.evaluators + dataset_evaluators
857 evaluator_outputs: list[EvaluationResult] = []
858 if evaluators:
859 evaluator_outputs_by_task = await task_group_gather(
860 [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
861 )
862 evaluator_outputs += [out for outputs in evaluator_outputs_by_task for out in outputs]
864 assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
865 case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
866 case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
867 case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
869 context = case_span.context
870 if context is None: # pragma: no cover
871 trace_id = ''
872 span_id = ''
873 else:
874 trace_id = f'{context.trace_id:032x}'
875 span_id = f'{context.span_id:016x}'
876 fallback_duration = time.time() - t0
878 report_inputs = to_jsonable_python(case.inputs)
880 return ReportCase(
881 name=report_case_name,
882 inputs=report_inputs,
883 metadata=case.metadata,
884 expected_output=case.expected_output,
885 output=scoring_context.output,
886 metrics=scoring_context.metrics,
887 attributes=scoring_context.attributes,
888 scores=scores,
889 labels=labels,
890 assertions=assertions,
891 task_duration=scoring_context.duration,
892 total_duration=_get_span_duration(case_span, fallback_duration),
893 trace_id=trace_id,
894 span_id=span_id,
895 )
898_evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])
901def _group_evaluator_outputs_by_type(
902 evaluation_results: Sequence[EvaluationResult],
903) -> tuple[
904 dict[str, EvaluationResult[bool]],
905 dict[str, EvaluationResult[int | float]],
906 dict[str, EvaluationResult[str]],
907]:
908 """Group evaluator outputs by their result type.
910 Args:
911 evaluation_results: Sequence of evaluation results to group.
913 Returns:
914 A tuple of dictionaries mapping evaluator names to their results, grouped by result type:
915 (success_evaluations, metric_evaluations, string_evaluations)
916 """
917 assertions: dict[str, EvaluationResult[bool]] = {}
918 scores: dict[str, EvaluationResult[int | float]] = {}
919 labels: dict[str, EvaluationResult[str]] = {}
920 seen_names = set[str]()
921 for er in evaluation_results:
922 name = er.name
923 # Dedupe repeated names by adding a numeric suffix
924 if name in seen_names:
925 suffix = 2
926 while f'{name}_{suffix}' in seen_names:
927 suffix += 1
928 name = f'{name}_{suffix}'
929 seen_names.add(name)
930 if assertion := er.downcast(bool):
931 assertions[name] = assertion
932 elif score := er.downcast(int, float):
933 scores[name] = score
934 elif label := er.downcast(str):  # coverage: 934 ↛ 921; line 934 didn't jump to line 921 because the condition on line 934 was always true
935 labels[name] = label
936 return assertions, scores, labels
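# For example (illustrative): a result valued True lands in `assertions`, one valued 0.85 in
# `scores`, and one valued 'positive' in `labels`; if two results share the name 'accuracy',
# the second is stored under 'accuracy_2'.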
939_CURRENT_TASK_RUN = ContextVar['_TaskRun | None']('_CURRENT_TASK_RUN', default=None)
942def set_eval_attribute(name: str, value: Any) -> None:
943 """Set an attribute on the current task run.
945 Args:
946 name: The name of the attribute.
947 value: The value of the attribute.
948 """
949 current_case = _CURRENT_TASK_RUN.get()
950 if current_case is not None:  # coverage: 950 ↛ exit; line 950 didn't return from function 'set_eval_attribute' because the condition on line 950 was always true
951 current_case.record_attribute(name, value)
954def increment_eval_metric(name: str, amount: int | float) -> None:
955 """Increment a metric on the current task run.
957 Args:
958 name: The name of the metric.
959 amount: The amount to increment by.
960 """
961 current_case = _CURRENT_TASK_RUN.get()
962 if current_case is not None:  # coverage: 962 ↛ exit; line 962 didn't return from function 'increment_eval_metric' because the condition on line 962 was always true
963 current_case.increment_metric(name, amount)
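# Example sketch: recording custom attributes and metrics from inside a task while it is being
# evaluated; the attribute and metric names below are illustrative.
#
#     async def my_task(inputs: dict) -> str:
#         set_eval_attribute('variant', 'baseline')
#         increment_eval_metric('custom_requests', 1)
#         return inputs['text'].upper()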
966def _get_span_duration(span: logfire_api.LogfireSpan, fallback: float) -> float:
967 """Calculate the duration of a span in seconds.
969 We prefer to obtain the duration from a span for the sake of consistency with observability and to make
970 the values more reliable during testing. However, if the span is not available (e.g. when using logfire_api
971 without logfire installed), we fall back to the provided duration.
973 Args:
974 span: The span to calculate the duration for.
975 fallback: The fallback duration to use if unable to obtain the duration from the span.
977 Returns:
978 The duration of the span in seconds.
979 """
980 try:
981 return (span.end_time - span.start_time) / 1_000_000_000 # type: ignore
982 except (AttributeError, TypeError): # pragma: no cover
983 return fallback
986def _get_registry(
987 custom_evaluator_types: Sequence[type[Evaluator[InputsT, OutputT, MetadataT]]],
988) -> Mapping[str, type[Evaluator[InputsT, OutputT, MetadataT]]]:
989 """Create a registry of evaluator types from default and custom evaluators.
991 Args:
992 custom_evaluator_types: Additional evaluator classes to include in the registry.
994 Returns:
995 A mapping from evaluator names to evaluator classes.
996 """
997 registry: dict[str, type[Evaluator[InputsT, OutputT, MetadataT]]] = {}
999 for evaluator_class in custom_evaluator_types:
1000 if not issubclass(evaluator_class, Evaluator):
1001 raise ValueError(
1002 f'All custom evaluator classes must be subclasses of Evaluator, but {evaluator_class} is not'
1003 )
1004 if '__dataclass_fields__' not in evaluator_class.__dict__:
1005 raise ValueError(
1006 f'All custom evaluator classes must be decorated with `@dataclass`, but {evaluator_class} is not'
1007 )
1008 name = evaluator_class.name()
1009 if name in registry:
1010 raise ValueError(f'Duplicate evaluator class name: {name!r}')
1011 registry[name] = evaluator_class
1013 for evaluator_class in DEFAULT_EVALUATORS:
1014 # Allow custom evaluators to override the default evaluators without raising an error
1015 registry.setdefault(evaluator_class.name(), evaluator_class)
1017 return registry
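# Example sketch: a custom evaluator that passes the registry checks above (an Evaluator
# subclass decorated with @dataclass); it is registered under its name, assumed here to default
# to the class name.
#
#     from dataclasses import dataclass
#     from pydantic_evals.evaluators import Evaluator, EvaluatorContext
#
#     @dataclass
#     class ContainsDigit(Evaluator):
#         def evaluate(self, ctx: EvaluatorContext) -> bool:
#             return any(ch.isdigit() for ch in str(ctx.output))
#
#     registry = _get_registry([ContainsDigit])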
1020def _load_evaluator_from_registry(
1021 registry: Mapping[str, type[Evaluator[InputsT, OutputT, MetadataT]]],
1022 case_name: str | None,
1023 spec: EvaluatorSpec,
1024) -> Evaluator[InputsT, OutputT, MetadataT]:
1025 """Load an evaluator from the registry based on a specification.
1027 Args:
1028 registry: Mapping from evaluator names to evaluator classes.
1029 case_name: Name of the case this evaluator will be used for, or None for dataset-level evaluators.
1030 spec: Specification of the evaluator to load.
1032 Returns:
1033 An initialized evaluator instance.
1035 Raises:
1036 ValueError: If the evaluator name is not found in the registry.
1037 """
1038 evaluator_class = registry.get(spec.name)
1039 if evaluator_class is None:
1040 raise ValueError(
1041 f'Evaluator {spec.name!r} is not in the provided registry. Registered choices: {list(registry.keys())}'
1042 )
1043 try:
1044 return evaluator_class(*spec.args, **spec.kwargs)
1045 except Exception as e:
1046 case_detail = f'case {case_name!r}' if case_name is not None else 'dataset'
1047 raise ValueError(f'Failed to instantiate evaluator {spec.name!r} for {case_detail}: {e}') from e