Coverage for pydantic_evals/pydantic_evals/reporting/__init__.py: 96.95% (490 statements)
coverage.py v7.6.12, created at 2025-03-28 17:27 +0000
1 from __future__ import annotations as _annotations
3 from collections import defaultdict
4 from collections.abc import Mapping
5 from dataclasses import dataclass, field
6 from io import StringIO
7 from typing import Any, Callable, Literal, Protocol, TypeVar
9 from pydantic import BaseModel
10 from rich.console import Console
11 from rich.table import Table
12 from typing_extensions import TypedDict
14 from pydantic_evals._utils import UNSET, Unset
16 from ..evaluators import EvaluationResult
17 from .render_numbers import (
18 default_render_duration,
19 default_render_duration_diff,
20 default_render_number,
21 default_render_number_diff,
22 default_render_percentage,
23)
25 __all__ = (
26 'EvaluationReport',
27 'ReportCase',
28 'EvaluationRenderer',
29 'RenderValueConfig',
30 'RenderNumberConfig',
31 'ReportCaseAggregate',
32)
34 MISSING_VALUE_STR = '[i]<missing>[/i]'
35 EMPTY_CELL_STR = '-'
36 EMPTY_AGGREGATE_CELL_STR = ''
39 class ReportCase(BaseModel):
40 """A single case in an evaluation report."""
42 name: str
43 """The name of the [case][pydantic_evals.Case]."""
44 inputs: Any
45 """The inputs to the task, from [`Case.inputs`][pydantic_evals.Case.inputs]."""
46 metadata: Any
47 """Any metadata associated with the case, from [`Case.metadata`][pydantic_evals.Case.metadata]."""
48 expected_output: Any
49 """The expected output of the task, from [`Case.expected_output`][pydantic_evals.Case.expected_output]."""
50 output: Any
51 """The output of the task execution."""
53 metrics: dict[str, float | int]
54 attributes: dict[str, Any]
56 scores: dict[str, EvaluationResult[int | float]] = field(init=False)
57 labels: dict[str, EvaluationResult[str]] = field(init=False)
58 assertions: dict[str, EvaluationResult[bool]] = field(init=False)
60 task_duration: float
61 total_duration: float # includes evaluator execution time
63 # TODO(DavidM): Drop these once we can reference child spans in details panel:
64 trace_id: str
65 span_id: str
68 class ReportCaseAggregate(BaseModel):
69 """A synthetic case that summarizes a set of cases."""
71 name: str
73 scores: dict[str, float | int]
74 labels: dict[str, dict[str, float]]
75 metrics: dict[str, float | int]
76 assertions: float | None
77 task_duration: float
78 total_duration: float
80 @staticmethod
81 def average(cases: list[ReportCase]) -> ReportCaseAggregate:
82 """Produce a synthetic "summary" case by averaging quantitative attributes."""
83 num_cases = len(cases)
84 if num_cases == 0:
85 return ReportCaseAggregate(
86 name='Averages',
87 scores={},
88 labels={},
89 metrics={},
90 assertions=None,
91 task_duration=0.0,
92 total_duration=0.0,
93 )
95 def _scores_averages(scores_by_name: list[dict[str, int | float | bool]]) -> dict[str, float]:
96 counts_by_name: dict[str, int] = defaultdict(int)
97 sums_by_name: dict[str, float] = defaultdict(float)
98 for sbn in scores_by_name:
99 for name, score in sbn.items():
100 counts_by_name[name] += 1
101 sums_by_name[name] += score
102 return {name: sums_by_name[name] / counts_by_name[name] for name in sums_by_name}
104 def _labels_averages(labels_by_name: list[dict[str, str]]) -> dict[str, dict[str, float]]:
105 counts_by_name: dict[str, int] = defaultdict(int)
106 sums_by_name: dict[str, dict[str, float]] = defaultdict(lambda: defaultdict(float))
107 for lbn in labels_by_name:
108 for name, label in lbn.items():
109 counts_by_name[name] += 1
110 sums_by_name[name][label] += 1
111 return {
112 name: {value: count / counts_by_name[name] for value, count in sums_by_name[name].items()}
113 for name in sums_by_name
114 }
116 average_task_duration = sum(case.task_duration for case in cases) / num_cases
117 average_total_duration = sum(case.total_duration for case in cases) / num_cases
119 # average_assertions: dict[str, float] = _scores_averages([{k: v.value for k, v in case.scores.items()} for case in cases])
120 average_scores: dict[str, float] = _scores_averages(
121 [{k: v.value for k, v in case.scores.items()} for case in cases]
122 )
123 average_labels: dict[str, dict[str, float]] = _labels_averages(
124 [{k: v.value for k, v in case.labels.items()} for case in cases]
125 )
126 average_metrics: dict[str, float] = _scores_averages([case.metrics for case in cases])
128 average_assertions: float | None = None
129 n_assertions = sum(len(case.assertions) for case in cases)
130 if n_assertions > 0:
131 n_passing = sum(1 for case in cases for assertion in case.assertions.values() if assertion.value)
132 average_assertions = n_passing / n_assertions
134 return ReportCaseAggregate(
135 name='Averages',
136 scores=average_scores,
137 labels=average_labels,
138 metrics=average_metrics,
139 assertions=average_assertions,
140 task_duration=average_task_duration,
141 total_duration=average_total_duration,
142 )
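# Worked example (hypothetical inputs, not part of the covered module) of how `average`
# aggregates cases: each score is averaged over the cases that define it, each label
# becomes a frequency distribution, and assertions collapse to an overall pass rate.
#
#   case 1: scores={'accuracy': 0.8},            labels={'tone': 'formal'}, assertions={'ok': True}
#   case 2: scores={'accuracy': 0.6, 'f1': 1.0}, labels={'tone': 'casual'}, assertions={'ok': False}
#
#   scores     -> {'accuracy': (0.8 + 0.6) / 2 = 0.7, 'f1': 1.0}
#   labels     -> {'tone': {'formal': 0.5, 'casual': 0.5}}
#   assertions -> 0.5  (1 passing out of 2)
#   task/total durations -> arithmetic mean over all cases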
145 class EvaluationReport(BaseModel):
146 """A report of the results of evaluating a model on a set of cases."""
148 name: str
149 """The name of the report."""
150 cases: list[ReportCase]
151 """The cases in the report."""
153 def print(
154 self,
155 width: int | None = None,
156 baseline: EvaluationReport | None = None,
157 include_input: bool = False,
158 include_output: bool = False,
159 include_total_duration: bool = False,
160 include_removed_cases: bool = False,
161 include_averages: bool = True,
162 input_config: RenderValueConfig | None = None,
163 output_config: RenderValueConfig | None = None,
164 score_configs: dict[str, RenderNumberConfig] | None = None,
165 label_configs: dict[str, RenderValueConfig] | None = None,
166 metric_configs: dict[str, RenderNumberConfig] | None = None,
167 duration_config: RenderNumberConfig | None = None,
168 ): # pragma: no cover
169 """Print this report to the console, optionally comparing it to a baseline report.
171 If you want more control over the output, use `console_table` instead and pass the resulting table to `rich.Console.print`.
172 """
173 table = self.console_table(
174 baseline=baseline,
175 include_input=include_input,
176 include_output=include_output,
177 include_total_duration=include_total_duration,
178 include_removed_cases=include_removed_cases,
179 include_averages=include_averages,
180 input_config=input_config,
181 output_config=output_config,
182 score_configs=score_configs,
183 label_configs=label_configs,
184 metric_configs=metric_configs,
185 duration_config=duration_config,
186 )
187 Console(width=width).print(table)
189 def console_table(
190 self,
191 baseline: EvaluationReport | None = None,
192 include_input: bool = False,
193 include_output: bool = False,
194 include_total_duration: bool = False,
195 include_removed_cases: bool = False,
196 include_averages: bool = True,
197 input_config: RenderValueConfig | None = None,
198 output_config: RenderValueConfig | None = None,
199 score_configs: dict[str, RenderNumberConfig] | None = None,
200 label_configs: dict[str, RenderValueConfig] | None = None,
201 metric_configs: dict[str, RenderNumberConfig] | None = None,
202 duration_config: RenderNumberConfig | None = None,
203 ) -> Table:
204 """Return a table containing the data from this report, or the diff between this report and a baseline report.
206 Optionally include input and output details.
207 """
208 renderer = EvaluationRenderer(
209 include_input=include_input,
210 include_output=include_output,
211 include_total_duration=include_total_duration,
212 include_removed_cases=include_removed_cases,
213 include_averages=include_averages,
214 input_config={**_DEFAULT_VALUE_CONFIG, **(input_config or {})},
215 output_config=output_config or _DEFAULT_VALUE_CONFIG,
216 score_configs=score_configs or {},
217 label_configs=label_configs or {},
218 metric_configs=metric_configs or {},
219 duration_config=duration_config or _DEFAULT_DURATION_CONFIG,
220 )
221 if baseline is None:
222 return renderer.build_table(self)
223 else: # pragma: no cover
224 return renderer.build_diff_table(self, baseline)
226 def __str__(self) -> str:
227 """Return a string representation of the report."""
228 table = self.console_table()
229 io_file = StringIO()
230 Console(file=io_file).print(table)
231 return io_file.getvalue()
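# Illustrative usage only, not part of the covered module. Assumes `report` and `baseline`
# are EvaluationReport instances produced elsewhere (e.g. by running an evaluation); the
# function name and argument values are hypothetical.
def _example_print_and_diff(report: EvaluationReport, baseline: EvaluationReport) -> str:
    # Print the report on its own, then as a diff against the baseline.
    report.print(include_input=True, include_total_duration=True)
    report.print(baseline=baseline, include_removed_cases=True)
    # For more control over rendering, build the Table and print it with a custom Console.
    table = report.console_table(baseline=baseline)
    Console(width=200).print(table)
    # __str__ renders the default table to plain text.
    return str(report)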
234 class RenderValueConfig(TypedDict, total=False):
235 """A configuration for rendering a value in an Evaluation report."""
237 value_formatter: str | Callable[[Any], str]
238 diff_checker: Callable[[Any, Any], bool] | None
239 diff_formatter: Callable[[Any, Any], str | None] | None
240 diff_style: str
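# A minimal sketch of a RenderValueConfig, e.g. passed via `label_configs` to
# `EvaluationReport.print`/`console_table`; the name and values below are hypothetical.
_example_label_config: RenderValueConfig = {
    'value_formatter': '{}',                            # a format spec, or a callable taking the value
    'diff_checker': lambda old, new: old != new,        # decides whether diff styling is applied
    'diff_formatter': lambda old, new: f'was {old!r}',  # extra text appended to a rendered diff
    'diff_style': 'yellow',
}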
243 @dataclass
244 class _ValueRenderer:
245 value_formatter: str | Callable[[Any], str] = '{}'
246 diff_checker: Callable[[Any, Any], bool] | None = lambda x, y: x != y
247 diff_formatter: Callable[[Any, Any], str | None] | None = None
248 diff_style: str = 'magenta'
250 @staticmethod
251 def from_config(config: RenderValueConfig) -> _ValueRenderer:
252 return _ValueRenderer(
253 value_formatter=config.get('value_formatter', '{}'),
254 diff_checker=config.get('diff_checker', lambda x, y: x != y),
255 diff_formatter=config.get('diff_formatter'),
256 diff_style=config.get('diff_style', 'magenta'),
257 )
259 def render_value(self, name: str | None, v: Any) -> str:
260 result = self._get_value_str(v)
261 if name:
262 result = f'{name}: {result}'
263 return result
265 def render_diff(self, name: str | None, old: Any | None, new: Any | None) -> str:
266 old_str = self._get_value_str(old) or MISSING_VALUE_STR
267 new_str = self._get_value_str(new) or MISSING_VALUE_STR
268 if old_str == new_str:
269 result = old_str
270 else:
271 result = f'{old_str} → {new_str}'
273 has_diff = self.diff_checker and self.diff_checker(old, new)
274 if has_diff:  # branch 274 ↛ 283: didn't jump to line 283 because the condition on line 274 was always true
275 # If there is a diff, make the name bold and compute the diff_str
276 name = name and f'[bold]{name}[/]'
277 diff_str = self.diff_formatter and self.diff_formatter(old, new)
278 if diff_str: # pragma: no cover
279 result += f' ({diff_str})'
280 result = f'[{self.diff_style}]{result}[/]'
282 # Add the name
283 if name:
284 result = f'{name}: {result}'
286 return result
288 def _get_value_str(self, value: Any) -> str:
289 if value is None:
290 return MISSING_VALUE_STR
291 if isinstance(self.value_formatter, str):
292 return self.value_formatter.format(value)
293 else:
294 return self.value_formatter(value)
297 class RenderNumberConfig(TypedDict, total=False):
298 """A configuration for rendering a particular score or metric in an Evaluation report.
300 See the implementation of `_NumberRenderer` for more clarity on how these parameters affect the rendering.
301 """
303 value_formatter: str | Callable[[float | int], str]
304 """The logic to use for formatting values.
306 * If not provided, format as ints if all values are ints, otherwise at least one decimal place and at least four significant figures.
307 * You can also use a custom string format spec, e.g. '{:.3f}'
308 * You can also use a custom function, e.g. lambda x: f'{x:.3f}'
309 """
310 diff_formatter: str | Callable[[float | int, float | int], str | None] | None
311 """The logic to use for formatting details about the diff.
313 The strings produced by the value_formatter will always be included in the reports, but the diff_formatter is
314 used to produce additional text about the difference between the old and new values, such as the absolute or
315 relative difference.
317 * If not provided, format as ints if all values are ints, otherwise at least one decimal place and at least four
318 significant figures, and will include the percentage change.
319 * You can also use a custom string format spec, e.g. '{:+.3f}'
320 * You can also use a custom function, e.g. lambda x: f'{x:+.3f}'.
321 If this function returns None, no extra diff text will be added.
322 * You can also use None to never generate extra diff text.
323 """
324 diff_atol: float
325 """The absolute tolerance for considering a difference "significant".
327 A difference is "significant" if `abs(new - old) >= self.diff_atol + self.diff_rtol * abs(old)`.
329 If a difference is not significant, it will not have the diff styles applied. Note that we still show
330 both the rendered before and after values in the diff any time they differ, even if the difference is not
331 significant. (If the rendered values are exactly the same, we only show the value once.)
333 If not provided, use 1e-6.
334 """
335 diff_rtol: float
336 """The relative tolerance for considering a difference "significant".
338 See the description of `diff_atol` for more details about what makes a difference "significant".
340 If not provided, use 0.001 if all values are ints, otherwise 0.05.
341 """
342 diff_increase_style: str
343 """The style to apply to diffed values that have a significant increase.
345 See the description of `diff_atol` for more details about what makes a difference "significant".
347 If not provided, use green for scores and red for metrics. You can also use arbitrary `rich` styles, such as "bold red".
348 """
349 diff_decrease_style: str
350 """The style to apply to diffed values that have a significant decrease.
352 See the description of `diff_atol` for more details about what makes a difference "significant".
354 If not provided, use red for scores and green for metrics. You can also use arbitrary `rich` styles, such as "bold red".
355 """
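# A minimal sketch of a RenderNumberConfig for a score, e.g. passed as
# `score_configs={'accuracy': ...}` to `EvaluationReport.print`/`console_table`;
# the score name and tolerances are hypothetical.
_example_score_config: RenderNumberConfig = {
    'value_formatter': '{:.2%}',                             # 0.875 renders as '87.50%'
    'diff_formatter': lambda old, new: f'{new - old:+.2%}',  # e.g. '+3.00%'
    'diff_atol': 0.005,
    'diff_rtol': 0.02,
    'diff_increase_style': 'bold green',
    'diff_decrease_style': 'bold red',
}
# With these tolerances a change is styled only when
# abs(new - old) >= diff_atol + diff_rtol * abs(old),
# e.g. 0.80 -> 0.83 is significant because 0.03 >= 0.005 + 0.02 * 0.80 = 0.021.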
358 @dataclass
359 class _NumberRenderer:
360 """See documentation of `RenderNumberConfig` for more details about the parameters here."""
362 value_formatter: str | Callable[[float | int], str]
363 diff_formatter: str | Callable[[float | int, float | int], str | None] | None
364 diff_atol: float
365 diff_rtol: float
366 diff_increase_style: str
367 diff_decrease_style: str
369 def render_value(self, name: str | None, v: float | int) -> str:
370 result = self._get_value_str(v)
371 if name:
372 result = f'{name}: {result}'
373 return result
375 def render_diff(self, name: str | None, old: float | int | None, new: float | int | None) -> str:
376 old_str = self._get_value_str(old)
377 new_str = self._get_value_str(new)
378 if old_str == new_str:
379 result = old_str
380 else:
381 result = f'{old_str} → {new_str}'
383 diff_style = self._get_diff_style(old, new)
384 if diff_style:
385 # If there is a diff, make the name bold and compute the diff_str
386 name = name and f'[bold]{name}[/]'
387 diff_str = self._get_diff_str(old, new)
388 if diff_str:  # branch 388 ↛ 390: didn't jump to line 390 because the condition on line 388 was always true
389 result += f' ({diff_str})'
390 result = f'[{diff_style}]{result}[/]'
392 # Add the name
393 if name:  # branch 393 ↛ 396: didn't jump to line 396 because the condition on line 393 was always true
394 result = f'{name}: {result}'
396 return result
398 @staticmethod
399 def infer_from_config(
400 config: RenderNumberConfig, kind: Literal['score', 'metric', 'duration'], values: list[float | int]
401 ) -> _NumberRenderer:
402 value_formatter = config.get('value_formatter', UNSET)
403 if isinstance(value_formatter, Unset):
404 value_formatter = default_render_number
406 diff_formatter = config.get('diff_formatter', UNSET)
407 if isinstance(diff_formatter, Unset):
408 diff_formatter = default_render_number_diff
410 diff_atol = config.get('diff_atol', UNSET)
411 if isinstance(diff_atol, Unset):
412 diff_atol = 1e-6
414 diff_rtol = config.get('diff_rtol', UNSET)
415 if isinstance(diff_rtol, Unset):
416 values_are_ints = all(isinstance(v, int) for v in values)
417 diff_rtol = 0.001 if values_are_ints else 0.05
419 diff_increase_style = config.get('diff_increase_style', UNSET)
420 if isinstance(diff_increase_style, Unset):
421 diff_increase_style = 'green' if kind == 'score' else 'red'
423 diff_decrease_style = config.get('diff_decrease_style', UNSET)
424 if isinstance(diff_decrease_style, Unset):
425 diff_decrease_style = 'red' if kind == 'score' else 'green'
427 return _NumberRenderer(
428 value_formatter=value_formatter,
429 diff_formatter=diff_formatter,
430 diff_rtol=diff_rtol,
431 diff_atol=diff_atol,
432 diff_increase_style=diff_increase_style,
433 diff_decrease_style=diff_decrease_style,
434 )
436 def _get_value_str(self, value: float | int | None) -> str:
437 if value is None:
438 return MISSING_VALUE_STR
439 if isinstance(self.value_formatter, str):
440 return self.value_formatter.format(value)
441 else:
442 return self.value_formatter(value)
444 def _get_diff_str(self, old: float | int | None, new: float | int | None) -> str | None:
445 if old is None or new is None: # pragma: no cover
446 return None
447 if isinstance(self.diff_formatter, str): # pragma: no cover
448 return self.diff_formatter.format(new - old)
449 elif self.diff_formatter is None: # pragma: no cover
450 return None
451 else:
452 return self.diff_formatter(old, new)
454 def _get_diff_style(self, old: float | int | None, new: float | int | None) -> str | None:
455 # Return the style for a significant increase or decrease, or None if the change is not significant
456 if old is None or new is None:
457 return None
459 diff = new - old
460 if abs(diff) < self.diff_atol + self.diff_rtol * abs(old): # pragma: no cover
461 return None
462 return self.diff_increase_style if diff > 0 else self.diff_decrease_style
465 T_contra = TypeVar('T_contra', contravariant=True)
468 class _AbstractRenderer(Protocol[T_contra]):
469 def render_value(self, name: str | None, v: T_contra) -> str: ...  # branch 469 ↛ exit: didn't return from function 'render_value'
471 def render_diff(self, name: str | None, old: T_contra | None, new: T_contra | None) -> str: ...  # branch 471 ↛ exit: didn't return from function 'render_diff'
474 _DEFAULT_NUMBER_CONFIG = RenderNumberConfig()
475 _DEFAULT_VALUE_CONFIG = RenderValueConfig()
476 _DEFAULT_DURATION_CONFIG = RenderNumberConfig(
477 value_formatter=default_render_duration,
478 diff_formatter=default_render_duration_diff,
479 diff_atol=1e-6, # one microsecond
480 diff_rtol=0.1,
481 diff_increase_style='red',
482 diff_decrease_style='green',
483)
486 T = TypeVar('T')
489 @dataclass
490 class ReportCaseRenderer:
491 include_input: bool
492 include_output: bool
493 include_scores: bool
494 include_labels: bool
495 include_metrics: bool
496 include_assertions: bool
497 include_total_duration: bool
499 input_renderer: _ValueRenderer
500 output_renderer: _ValueRenderer
501 score_renderers: dict[str, _NumberRenderer]
502 label_renderers: dict[str, _ValueRenderer]
503 metric_renderers: dict[str, _NumberRenderer]
504 duration_renderer: _NumberRenderer
506 def build_base_table(self, title: str) -> Table:
507 """Build and return the base Rich Table used for both report and diff output."""
508 table = Table(title=title, show_lines=True)
509 table.add_column('Case ID', style='bold')
510 if self.include_input:
511 table.add_column('Inputs', overflow='fold')
512 if self.include_output:
513 table.add_column('Outputs', overflow='fold')
514 if self.include_scores:
515 table.add_column('Scores', overflow='fold')
516 if self.include_labels:
517 table.add_column('Labels', overflow='fold')
518 if self.include_metrics:
519 table.add_column('Metrics', overflow='fold')
520 if self.include_assertions:
521 table.add_column('Assertions', overflow='fold')
522 table.add_column('Durations' if self.include_total_duration else 'Duration', justify='right')
523 return table
525 def build_row(self, case: ReportCase) -> list[str]:
526 """Build a table row for a single case."""
527 row = [case.name]
529 if self.include_input:
530 row.append(self.input_renderer.render_value(None, case.inputs) or EMPTY_CELL_STR)
532 if self.include_output:
533 row.append(self.output_renderer.render_value(None, case.output) or EMPTY_CELL_STR)
535 if self.include_scores:
536 row.append(self._render_dict({k: v.value for k, v in case.scores.items()}, self.score_renderers))
538 if self.include_labels:
539 row.append(self._render_dict({k: v.value for k, v in case.labels.items()}, self.label_renderers))
541 if self.include_metrics:
542 row.append(self._render_dict(case.metrics, self.metric_renderers))
544 if self.include_assertions:
545 row.append(self._render_assertions(list(case.assertions.values())))
547 row.append(self._render_durations(case))
548 return row
550 def build_aggregate_row(self, aggregate: ReportCaseAggregate) -> list[str]:
551 """Build a table row for an aggregated case."""
552 row = [f'[b i]{aggregate.name}[/]']
554 if self.include_input:
555 row.append(EMPTY_AGGREGATE_CELL_STR)
557 if self.include_output:
558 row.append(EMPTY_AGGREGATE_CELL_STR)
560 if self.include_scores:
561 row.append(self._render_dict(aggregate.scores, self.score_renderers))
563 if self.include_labels:
564 row.append(self._render_dict(aggregate.labels, self.label_renderers))
566 if self.include_metrics:
567 row.append(self._render_dict(aggregate.metrics, self.metric_renderers))
569 if self.include_assertions:
570 row.append(self._render_aggregate_assertions(aggregate.assertions))
572 row.append(self._render_durations(aggregate))
573 return row
575 def build_diff_row(
576 self,
577 new_case: ReportCase,
578 baseline: ReportCase,
579 ) -> list[str]:
580 """Build a table row for a given case ID."""
581 assert baseline.name == new_case.name, 'This should only be called for matching case IDs'
582 row = [baseline.name]
584 if self.include_input:  # branch 584 ↛ 588: didn't jump to line 588 because the condition on line 584 was always true
585 input_diff = self.input_renderer.render_diff(None, baseline.inputs, new_case.inputs) or EMPTY_CELL_STR
586 row.append(input_diff)
588 if self.include_output:  # branch 588 ↛ 592: didn't jump to line 592 because the condition on line 588 was always true
589 output_diff = self.output_renderer.render_diff(None, baseline.output, new_case.output) or EMPTY_CELL_STR
590 row.append(output_diff)
592 if self.include_scores:  # branch 592 ↛ 600: didn't jump to line 600 because the condition on line 592 was always true
593 scores_diff = self._render_dicts_diff(
594 {k: v.value for k, v in baseline.scores.items()},
595 {k: v.value for k, v in new_case.scores.items()},
596 self.score_renderers,
597 )
598 row.append(scores_diff)
600 if self.include_labels:  # branch 600 ↛ 604: didn't jump to line 604 because the condition on line 600 was always true
601 labels_diff = self._render_dicts_diff(baseline.labels, new_case.labels, self.label_renderers)
602 row.append(labels_diff)
604 if self.include_metrics:  # branch 604 ↛ 608: didn't jump to line 608 because the condition on line 604 was always true
605 metrics_diff = self._render_dicts_diff(baseline.metrics, new_case.metrics, self.metric_renderers)
606 row.append(metrics_diff)
608 if self.include_assertions:  # branch 608 ↛ 614: didn't jump to line 614 because the condition on line 608 was always true
609 assertions_diff = self._render_assertions_diff(
610 list(baseline.assertions.values()), list(new_case.assertions.values())
611 )
612 row.append(assertions_diff)
614 row.append(self._render_durations_diff(baseline, new_case))
616 return row
618 def build_diff_aggregate_row(
619 self,
620 new: ReportCaseAggregate,
621 baseline: ReportCaseAggregate,
622 ) -> list[str]:
623 """Build a diff table row for the aggregate ("Averages") rows of the two reports."""
624 assert baseline.name == new.name, 'This should only be called for aggregates with matching names'
625 row = [f'[b i]{baseline.name}[/]']
627 if self.include_input:  # branch 627 ↛ 630: didn't jump to line 630 because the condition on line 627 was always true
628 row.append(EMPTY_AGGREGATE_CELL_STR)
630 if self.include_output:  # branch 630 ↛ 633: didn't jump to line 633 because the condition on line 630 was always true
631 row.append(EMPTY_AGGREGATE_CELL_STR)
633 if self.include_scores:  # branch 633 ↛ 637: didn't jump to line 637 because the condition on line 633 was always true
634 scores_diff = self._render_dicts_diff(baseline.scores, new.scores, self.score_renderers)
635 row.append(scores_diff)
637 if self.include_labels:  # branch 637 ↛ 641: didn't jump to line 641 because the condition on line 637 was always true
638 labels_diff = self._render_dicts_diff(baseline.labels, new.labels, self.label_renderers)
639 row.append(labels_diff)
641 if self.include_metrics:  # branch 641 ↛ 645: didn't jump to line 645 because the condition on line 641 was always true
642 metrics_diff = self._render_dicts_diff(baseline.metrics, new.metrics, self.metric_renderers)
643 row.append(metrics_diff)
645 if self.include_assertions:  # branch 645 ↛ 649: didn't jump to line 649 because the condition on line 645 was always true
646 assertions_diff = self._render_aggregate_assertions_diff(baseline.assertions, new.assertions)
647 row.append(assertions_diff)
649 row.append(self._render_durations_diff(baseline, new))
651 return row
653 def _render_durations(self, case: ReportCase | ReportCaseAggregate) -> str:
654 """Build the string for the duration column of a single row."""
655 case_durations: dict[str, float] = {'task': case.task_duration}
656 if self.include_total_duration:
657 case_durations['total'] = case.total_duration
658 return self._render_dict(
659 case_durations,
660 {'task': self.duration_renderer, 'total': self.duration_renderer},
661 include_names=self.include_total_duration,
662 )
664 def _render_durations_diff(
665 self,
666 base_case: ReportCase | ReportCaseAggregate,
667 new_case: ReportCase | ReportCaseAggregate,
668 ) -> str:
669 """Build the diff string for a duration value."""
670 base_case_durations: dict[str, float] = {'task': base_case.task_duration}
671 new_case_durations: dict[str, float] = {'task': new_case.task_duration}
672 if self.include_total_duration:  # branch 672 ↛ 675: didn't jump to line 675 because the condition on line 672 was always true
673 base_case_durations['total'] = base_case.total_duration
674 new_case_durations['total'] = new_case.total_duration
675 return self._render_dicts_diff(
676 base_case_durations,
677 new_case_durations,
678 {'task': self.duration_renderer, 'total': self.duration_renderer},
679 include_names=self.include_total_duration,
680 )
682 @staticmethod
683 def _render_dicts_diff(
684 baseline_dict: dict[str, T],
685 new_dict: dict[str, T],
686 renderers: Mapping[str, _AbstractRenderer[T]],
687 *,
688 include_names: bool = True,
689 ) -> str:
690 keys: set[str] = set()
691 keys.update(baseline_dict.keys())
692 keys.update(new_dict.keys())
693 diff_lines: list[str] = []
694 for key in sorted(keys):
695 old_val = baseline_dict.get(key)
696 new_val = new_dict.get(key)
697 rendered = renderers[key].render_diff(key if include_names else None, old_val, new_val)
698 diff_lines.append(rendered)
699 return '\n'.join(diff_lines) if diff_lines else EMPTY_CELL_STR
701 @staticmethod
702 def _render_dict(
703 case_dict: dict[str, T],
704 renderers: Mapping[str, _AbstractRenderer[T]],
705 *,
706 include_names: bool = True,
707 ) -> str:
708 diff_lines: list[str] = []
709 for key, val in case_dict.items():
710 rendered = renderers[key].render_value(key if include_names else None, val)
711 diff_lines.append(rendered)
712 return '\n'.join(diff_lines) if diff_lines else EMPTY_CELL_STR
714 @staticmethod
715 def _render_assertions(
716 assertions: list[EvaluationResult[bool]],
717 ) -> str:
718 if not assertions:
719 return EMPTY_CELL_STR
720 return ''.join(['[green]✔[/]' if a.value else '[red]✗[/]' for a in assertions])
722 @staticmethod
723 def _render_aggregate_assertions(
724 assertions: float | None,
725 ) -> str:
726 return (
727 default_render_percentage(assertions) + ' [green]✔[/]'
728 if assertions is not None
729 else EMPTY_AGGREGATE_CELL_STR
730 )
732 @staticmethod
733 def _render_assertions_diff(
734 assertions: list[EvaluationResult[bool]], new_assertions: list[EvaluationResult[bool]]
735 ) -> str:
736 if not assertions and not new_assertions: # pragma: no cover
737 return EMPTY_CELL_STR
739 old = ''.join(['[green]✔[/]' if a.value else '[red]✗[/]' for a in assertions])
740 new = ''.join(['[green]✔[/]' if a.value else '[red]✗[/]' for a in new_assertions])
741 return old if old == new else f'{old} → {new}'
743 @staticmethod
744 def _render_aggregate_assertions_diff(
745 baseline: float | None,
746 new: float | None,
747 ) -> str:
748 if baseline is None and new is None: # pragma: no cover
749 return EMPTY_AGGREGATE_CELL_STR
750 rendered_baseline = (
751 default_render_percentage(baseline) + ' [green]✔[/]' if baseline is not None else EMPTY_CELL_STR
752 )
753 rendered_new = default_render_percentage(new) + ' [green]✔[/]' if new is not None else EMPTY_CELL_STR
754 return rendered_new if rendered_baseline == rendered_new else f'{rendered_baseline} → {rendered_new}'
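# Illustrative behaviour, not part of the covered module: _render_assertions emits one
# marker per assertion, e.g. values [True, False, True] -> '[green]✔[/][red]✗[/][green]✔[/]',
# while the aggregate assertion cell shows the pass rate as a percentage followed by ' [green]✔[/]'.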
757 @dataclass
758 class EvaluationRenderer:
759 """A class for rendering an EvaluationReport or the diff between two EvaluationReports."""
761 # Columns to include
762 include_input: bool
763 include_output: bool
764 include_total_duration: bool
766 # Rows to include
767 include_removed_cases: bool
768 include_averages: bool
770 input_config: RenderValueConfig
771 output_config: RenderValueConfig
772 score_configs: dict[str, RenderNumberConfig]
773 label_configs: dict[str, RenderValueConfig]
774 metric_configs: dict[str, RenderNumberConfig]
775 duration_config: RenderNumberConfig
777 def include_scores(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
778 return any(case.scores for case in self._all_cases(report, baseline))
780 def include_labels(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
781 return any(case.labels for case in self._all_cases(report, baseline))
783 def include_metrics(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
784 return any(case.metrics for case in self._all_cases(report, baseline))
786 def include_assertions(self, report: EvaluationReport, baseline: EvaluationReport | None = None):
787 return any(case.assertions for case in self._all_cases(report, baseline))
789 def _all_cases(self, report: EvaluationReport, baseline: EvaluationReport | None) -> list[ReportCase]:
790 if not baseline:
791 return report.cases
792 else:
793 return report.cases + self._baseline_cases_to_include(report, baseline)
795 def _baseline_cases_to_include(self, report: EvaluationReport, baseline: EvaluationReport) -> list[ReportCase]:
796 if self.include_removed_cases:
797 return baseline.cases
798 report_case_names = {case.name for case in report.cases}
799 return [case for case in baseline.cases if case.name in report_case_names]
801 def _get_case_renderer(
802 self, report: EvaluationReport, baseline: EvaluationReport | None = None
803 ) -> ReportCaseRenderer:
804 input_renderer = _ValueRenderer.from_config(self.input_config)
805 output_renderer = _ValueRenderer.from_config(self.output_config)
806 score_renderers = self._infer_score_renderers(report, baseline)
807 label_renderers = self._infer_label_renderers(report, baseline)
808 metric_renderers = self._infer_metric_renderers(report, baseline)
809 duration_renderer = _NumberRenderer.infer_from_config(
810 self.duration_config, 'duration', [x.task_duration for x in self._all_cases(report, baseline)]
811 )
813 return ReportCaseRenderer(
814 include_input=self.include_input,
815 include_output=self.include_output,
816 include_scores=self.include_scores(report, baseline),
817 include_labels=self.include_labels(report, baseline),
818 include_metrics=self.include_metrics(report, baseline),
819 include_assertions=self.include_assertions(report, baseline),
820 include_total_duration=self.include_total_duration,
821 input_renderer=input_renderer,
822 output_renderer=output_renderer,
823 score_renderers=score_renderers,
824 label_renderers=label_renderers,
825 metric_renderers=metric_renderers,
826 duration_renderer=duration_renderer,
827 )
829 def build_table(self, report: EvaluationReport) -> Table:
830 case_renderer = self._get_case_renderer(report)
831 table = case_renderer.build_base_table(f'Evaluation Summary: {report.name}')
832 for case in report.cases:
833 table.add_row(*case_renderer.build_row(case))
835 if self.include_averages:  # branch 835 ↛ 838: didn't jump to line 838 because the condition on line 835 was always true
836 average = ReportCaseAggregate.average(report.cases)
837 table.add_row(*case_renderer.build_aggregate_row(average))
838 return table
840 def build_diff_table(self, report: EvaluationReport, baseline: EvaluationReport) -> Table:
841 report_cases = report.cases
842 baseline_cases = self._baseline_cases_to_include(report, baseline)
844 report_cases_by_id = {case.name: case for case in report_cases}
845 baseline_cases_by_id = {case.name: case for case in baseline_cases}
847 diff_cases: list[tuple[ReportCase, ReportCase]] = []
848 removed_cases: list[ReportCase] = []
849 added_cases: list[ReportCase] = []
851 for case_id in sorted(set(baseline_cases_by_id.keys()) | set(report_cases_by_id.keys())):
852 maybe_baseline_case = baseline_cases_by_id.get(case_id)
853 maybe_report_case = report_cases_by_id.get(case_id)
854 if maybe_baseline_case and maybe_report_case:
855 diff_cases.append((maybe_baseline_case, maybe_report_case))
856 elif maybe_baseline_case:
857 removed_cases.append(maybe_baseline_case)
858 elif maybe_report_case:
859 added_cases.append(maybe_report_case)
860 else: # pragma: no cover
861 assert False, 'This should be unreachable'
863 case_renderer = self._get_case_renderer(report, baseline)
864 diff_name = baseline.name if baseline.name == report.name else f'{baseline.name} → {report.name}'
865 table = case_renderer.build_base_table(f'Evaluation Diff: {diff_name}')
866 for baseline_case, new_case in diff_cases:
867 table.add_row(*case_renderer.build_diff_row(new_case, baseline_case))
868 for case in added_cases:
869 row = case_renderer.build_row(case)
870 row[0] = f'[green]+ Added Case[/]\n{row[0]}'
871 table.add_row(*row)
872 for case in removed_cases:
873 row = case_renderer.build_row(case)
874 row[0] = f'[red]- Removed Case[/]\n{row[0]}'
875 table.add_row(*row)
877 if self.include_averages:  # branch 877 ↛ 882: didn't jump to line 882 because the condition on line 877 was always true
878 report_average = ReportCaseAggregate.average(report_cases)
879 baseline_average = ReportCaseAggregate.average(baseline_cases)
880 table.add_row(*case_renderer.build_diff_aggregate_row(report_average, baseline_average))
882 return table
884 def _infer_score_renderers(
885 self, report: EvaluationReport, baseline: EvaluationReport | None
886 ) -> dict[str, _NumberRenderer]:
887 all_cases = self._all_cases(report, baseline)
889 values_by_name: dict[str, list[float | int]] = {}
890 for case in all_cases:
891 for k, score in case.scores.items():
892 values_by_name.setdefault(k, []).append(score.value)
894 all_renderers: dict[str, _NumberRenderer] = {}
895 for name, values in values_by_name.items():
896 merged_config = _DEFAULT_NUMBER_CONFIG.copy()
897 merged_config.update(self.score_configs.get(name, {}))
898 all_renderers[name] = _NumberRenderer.infer_from_config(merged_config, 'score', values)
899 return all_renderers
901 def _infer_label_renderers(
902 self, report: EvaluationReport, baseline: EvaluationReport | None
903 ) -> dict[str, _ValueRenderer]:
904 all_cases = self._all_cases(report, baseline)
905 all_names: set[str] = set()
906 for case in all_cases:
907 for k in case.labels:
908 all_names.add(k)
910 all_renderers: dict[str, _ValueRenderer] = {}
911 for name in all_names:
912 merged_config = _DEFAULT_VALUE_CONFIG.copy()
913 merged_config.update(self.label_configs.get(name, {}))
914 all_renderers[name] = _ValueRenderer.from_config(merged_config)
915 return all_renderers
917 def _infer_metric_renderers(
918 self, report: EvaluationReport, baseline: EvaluationReport | None
919 ) -> dict[str, _NumberRenderer]:
920 all_cases = self._all_cases(report, baseline)
922 values_by_name: dict[str, list[float | int]] = {}
923 for case in all_cases:
924 for k, v in case.metrics.items():
925 values_by_name.setdefault(k, []).append(v)
927 all_renderers: dict[str, _NumberRenderer] = {}
928 for name, values in values_by_name.items():
929 merged_config = _DEFAULT_NUMBER_CONFIG.copy()
930 merged_config.update(self.metric_configs.get(name, {}))
931 all_renderers[name] = _NumberRenderer.infer_from_config(merged_config, 'metric', values)
932 return all_renderers
934 def _infer_duration_renderer(
935 self, report: EvaluationReport, baseline: EvaluationReport | None
936 ) -> _NumberRenderer: # pragma: no cover
937 all_cases = self._all_cases(report, baseline)
938 all_durations = [x.task_duration for x in all_cases]
939 if self.include_total_duration:
940 all_durations += [x.total_duration for x in all_cases]
941 return _NumberRenderer.infer_from_config(self.duration_config, 'duration', all_durations)