Coverage for pydantic_ai_slim/pydantic_ai/_pydantic.py: 98.59%
104 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1"""Used to build pydantic validators and JSON schemas from functions.
3This module has to use numerous internal Pydantic APIs and is therefore brittle to changes in Pydantic.
4"""
6from __future__ import annotations as _annotations
8from inspect import Parameter, signature
9from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
11from pydantic import ConfigDict
12from pydantic._internal import _decorators, _generate_schema, _typing_extra
13from pydantic._internal._config import ConfigWrapper
14from pydantic.fields import FieldInfo
15from pydantic.json_schema import GenerateJsonSchema
16from pydantic.plugin._schema_validator import create_schema_validator
17from pydantic_core import SchemaValidator, core_schema
19from ._griffe import doc_descriptions
20from ._utils import check_object_json_schema, is_model_like
22if TYPE_CHECKING:
23 from .tools import DocstringFormat, ObjectJsonSchema
# Only `function_schema` is part of this module's star-export surface; other module-level helpers
# (e.g. `takes_ctx`) are still importable explicitly by name.
__all__ = ('function_schema',)
class FunctionSchema(TypedDict):
    """Internal information about a function schema.

    Produced by `function_schema` from a tool function's signature, type hints and docstring.
    """

    # tool description, taken from the function docstring (or, failing that, from the single
    # model-like argument's JSON schema description)
    description: str
    # validator for the tool's arguments (a pydantic `PluggableSchemaValidator`, cast to `SchemaValidator`)
    validator: SchemaValidator
    # JSON schema describing the tool's arguments
    json_schema: ObjectJsonSchema
    # if not None, the function takes a single model-like argument by that name (besides potentially
    # the `RunContext` argument), and `validator` / `json_schema` describe that argument directly
    # rather than a dict of all arguments
    single_arg_name: str | None
    # names of positional-only parameters, in signature order
    positional_fields: list[str]
    # name of the `*args` parameter, if the function has one
    var_positional_field: str | None
def function_schema(  # noqa: C901
    function: Callable[..., Any],
    takes_ctx: bool,
    docstring_format: DocstringFormat,
    require_parameter_descriptions: bool,
) -> FunctionSchema:
    """Build a Pydantic validator and JSON schema from a tool function.

    Args:
        function: The function to build a validator and JSON schema for.
        takes_ctx: Whether the function takes a `RunContext` first argument.
        docstring_format: The docstring format to use.
        require_parameter_descriptions: Whether to require descriptions for all tool function parameters.
            When `takes_ctx` is true the first (`RunContext`) parameter is exempt, since it is never
            part of the generated schema.

    Returns:
        A `FunctionSchema` instance.

    Raises:
        UserError: If any problems were found with the function's parameters; all problems are
            collected and reported together in one error.
    """
    config = ConfigDict(title=function.__name__)
    config_wrapper = ConfigWrapper(config)
    gen_schema = _generate_schema.GenerateSchema(config_wrapper)

    sig = signature(function)

    type_hints = _typing_extra.get_function_type_hints(function)

    var_kwargs_schema: core_schema.CoreSchema | None = None
    fields: dict[str, core_schema.TypedDictField] = {}
    positional_fields: list[str] = []
    var_positional_field: str | None = None
    errors: list[str] = []
    decorators = _decorators.DecoratorInfos()

    description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)

    if require_parameter_descriptions:
        # Compare names, not counts: a count comparison flags the `RunContext` parameter, lets a
        # description of a non-existent parameter mask a missing one, and can report an empty list.
        missing_params = _missing_parameter_descriptions(list(sig.parameters), field_descriptions, takes_ctx)
        if missing_params:
            errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}')

    for index, (name, p) in enumerate(sig.parameters.items()):
        if p.annotation is sig.empty:
            if takes_ctx and index == 0:
                # should be the `context` argument, skip
                continue
            # TODO warn?
            annotation = Any
        else:
            annotation = type_hints[name]

        if index == 0 and takes_ctx:
            if not _is_call_ctx(annotation):
                errors.append('First parameter of tools that take context must be annotated with RunContext[...]')
            # the context parameter is never part of the tool's argument schema
            continue
        elif not takes_ctx and _is_call_ctx(annotation):
            errors.append('RunContext annotations can only be used with tools that take context')
            continue
        elif index != 0 and _is_call_ctx(annotation):
            errors.append('RunContext annotations can only be used as the first argument')
            continue

        field_name = p.name
        if p.kind == Parameter.VAR_KEYWORD:
            # `**kwargs` becomes the "extras" schema of the typed dict rather than a named field
            var_kwargs_schema = gen_schema.generate_schema(annotation)
        else:
            if p.kind == Parameter.VAR_POSITIONAL:
                # `*args: T` is exposed as a single list-valued field
                annotation = list[annotation]

            # FieldInfo.from_annotation expects a type, `annotation` is Any
            annotation = cast(type[Any], annotation)
            field_info = FieldInfo.from_annotation(annotation)
            if field_info.description is None:
                field_info.description = field_descriptions.get(field_name)

            fields[field_name] = td_schema = gen_schema._generate_td_field_schema(  # pyright: ignore[reportPrivateUsage]
                field_name,
                field_info,
                decorators,
                required=p.default is Parameter.empty,
            )
            # read back by `_build_schema` to decide whether a lone argument can be validated directly
            # noinspection PyTypeChecker
            td_schema.setdefault('metadata', {})['is_model_like'] = is_model_like(annotation)

            if p.kind == Parameter.POSITIONAL_ONLY:
                positional_fields.append(field_name)
            elif p.kind == Parameter.VAR_POSITIONAL:
                var_positional_field = field_name

    if errors:
        from .exceptions import UserError

        error_details = '\n '.join(errors)
        raise UserError(f'Error generating schema for {function.__qualname__}:\n {error_details}')

    core_config = config_wrapper.core_config(None)
    # noinspection PyTypedDict
    core_config['extra_fields_behavior'] = 'allow' if var_kwargs_schema else 'forbid'

    schema, single_arg_name = _build_schema(fields, var_kwargs_schema, gen_schema, core_config)
    schema = gen_schema.clean_schema(schema)
    # noinspection PyUnresolvedReferences
    schema_validator = create_schema_validator(
        schema,
        function,
        function.__module__,
        function.__qualname__,
        'validate_call',
        core_config,
        config_wrapper.plugin_settings,
    )
    # PluggableSchemaValidator is api compatible with SchemaValidator
    schema_validator = cast(SchemaValidator, schema_validator)
    json_schema = GenerateJsonSchema().generate(schema)

    # workaround for https://github.com/pydantic/pydantic/issues/10785
    # if we build a custom TypeDict schema (matches when `single_arg_name is None`), we manually set
    # `additionalProperties` in the JSON Schema
    if single_arg_name is None:
        json_schema['additionalProperties'] = bool(var_kwargs_schema)
    elif not description:
        # if the tool description is not set, and we have a single parameter, take the description from that
        # and set it on the tool
        description = json_schema.pop('description', None)

    return FunctionSchema(
        description=description,
        validator=schema_validator,
        json_schema=check_object_json_schema(json_schema),
        single_arg_name=single_arg_name,
        positional_fields=positional_fields,
        var_positional_field=var_positional_field,
    )


def _missing_parameter_descriptions(
    param_names: list[str],
    field_descriptions: dict[str, str],
    takes_ctx: bool,
) -> list[str]:
    """Return, in signature order, the parameter names that lack a docstring description.

    Args:
        param_names: All parameter names of the tool function, in signature order.
        field_descriptions: Parameter descriptions extracted from the docstring, keyed by name.
        takes_ctx: Whether the first parameter is the `RunContext`, which needs no description.

    Returns:
        The names with no entry in `field_descriptions`; empty when every parameter is described.
    """
    # The context parameter is skipped when building fields, so it never appears in the schema
    # and does not need a description.
    candidates = param_names[1:] if takes_ctx else param_names
    return [name for name in candidates if name not in field_descriptions]
def takes_ctx(function: Callable[..., Any]) -> bool:
    """Check if a function takes a `RunContext` first argument.

    Args:
        function: The function to check.

    Returns:
        `True` if the function takes a `RunContext` as first argument, `False` otherwise,
        including when the function has no parameters or its first parameter is unannotated.
    """
    sig = signature(function)
    try:
        first_param_name = next(iter(sig.parameters))
    except StopIteration:
        return False
    type_hints = _typing_extra.get_function_type_hints(function)
    # Resolved type hints may have no entry for an unannotated parameter; indexing would raise
    # KeyError, so treat a missing entry as "not annotated" (and therefore not a RunContext).
    annotation = type_hints.get(first_param_name, sig.empty)
    return annotation is not sig.empty and _is_call_ctx(annotation)
194def _build_schema(
195 fields: dict[str, core_schema.TypedDictField],
196 var_kwargs_schema: core_schema.CoreSchema | None,
197 gen_schema: _generate_schema.GenerateSchema,
198 core_config: core_schema.CoreConfig,
199) -> tuple[core_schema.CoreSchema, str | None]:
200 """Generate a typed dict schema for function parameters.
202 Args:
203 fields: The fields to generate a typed dict schema for.
204 var_kwargs_schema: The variable keyword arguments schema.
205 gen_schema: The `GenerateSchema` instance.
206 core_config: The core configuration.
208 Returns:
209 tuple of (generated core schema, single arg name).
210 """
211 if len(fields) == 1 and var_kwargs_schema is None:
212 name = next(iter(fields))
213 td_field = fields[name]
214 if td_field['metadata']['is_model_like']: # type: ignore
215 return td_field['schema'], name
217 td_schema = core_schema.typed_dict_schema(
218 fields,
219 config=core_config,
220 extras_schema=gen_schema.generate_schema(var_kwargs_schema) if var_kwargs_schema else None,
221 )
222 return td_schema, None
def _is_call_ctx(annotation: Any) -> bool:
    """Return whether `annotation` is `RunContext`, either bare or parametrized (e.g. `RunContext[Deps]`)."""
    # Imported inside the function, presumably to avoid a circular import with `.tools`
    # (NOTE(review): confirm).
    from .tools import RunContext

    if annotation is RunContext:
        return True
    # A parametrized form is a generic alias object whose origin is the `RunContext` class itself.
    return _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext