Coverage for pydantic_ai_slim/pydantic_ai/_pydantic.py: 98.59%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1"""Used to build pydantic validators and JSON schemas from functions. 

2 

3This module has to use numerous internal Pydantic APIs and is therefore brittle to changes in Pydantic. 

4""" 

5 

6from __future__ import annotations as _annotations 

7 

8from inspect import Parameter, signature 

9from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin 

10 

11from pydantic import ConfigDict 

12from pydantic._internal import _decorators, _generate_schema, _typing_extra 

13from pydantic._internal._config import ConfigWrapper 

14from pydantic.fields import FieldInfo 

15from pydantic.json_schema import GenerateJsonSchema 

16from pydantic.plugin._schema_validator import create_schema_validator 

17from pydantic_core import SchemaValidator, core_schema 

18 

19from ._griffe import doc_descriptions 

20from ._utils import check_object_json_schema, is_model_like 

21 

22if TYPE_CHECKING: 

23 from .tools import DocstringFormat, ObjectJsonSchema 

24 

25 

# Only `function_schema` is exported by a star-import; every other name in this module is left out.
__all__ = ('function_schema',)

27 

28 

class FunctionSchema(TypedDict):
    """Internal information about a function schema."""

    # description of the tool, as returned by `doc_descriptions` or taken from the single parameter's schema
    description: str
    # validator for the function's arguments, created via `create_schema_validator`
    validator: SchemaValidator
    # JSON schema generated from the same core schema that backs `validator`
    json_schema: ObjectJsonSchema
    # if not None, the function takes a single argument by that name (besides potentially `info`)
    single_arg_name: str | None
    # names of the positional-only parameters, in signature order
    positional_fields: list[str]
    # name of the `*args` parameter, if the function declares one
    var_positional_field: str | None

39 

40 

def function_schema(  # noqa: C901
    function: Callable[..., Any],
    takes_ctx: bool,
    docstring_format: DocstringFormat,
    require_parameter_descriptions: bool,
) -> FunctionSchema:
    """Build a Pydantic validator and JSON schema from a tool function.

    Args:
        function: The function to build a validator and JSON schema for.
        takes_ctx: Whether the function takes a `RunContext` first argument.
        docstring_format: The docstring format to use.
        require_parameter_descriptions: Whether to require descriptions for all tool function parameters.
            The leading context parameter of a tool that takes context is exempt, since it is skipped
            when the schema fields are built.

    Returns:
        A `FunctionSchema` instance.

    Raises:
        UserError: If a required parameter description is missing, or if a `RunContext` annotation is
            missing from / misplaced in the function signature. All problems found are reported together.
    """
    config = ConfigDict(title=function.__name__)
    config_wrapper = ConfigWrapper(config)
    gen_schema = _generate_schema.GenerateSchema(config_wrapper)

    sig = signature(function)

    type_hints = _typing_extra.get_function_type_hints(function)

    var_kwargs_schema: core_schema.CoreSchema | None = None
    fields: dict[str, core_schema.TypedDictField] = {}
    positional_fields: list[str] = []
    var_positional_field: str | None = None
    errors: list[str] = []
    decorators = _decorators.DecoratorInfos()

    description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)

    if require_parameter_descriptions:
        # Compare by name rather than by count, keep signature order so the message is stable, and
        # skip the context parameter because it never becomes a schema field.
        missing_params = [
            name
            for index, name in enumerate(sig.parameters)
            if not (takes_ctx and index == 0) and name not in field_descriptions
        ]
        if missing_params:
            errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}')

    for index, (name, p) in enumerate(sig.parameters.items()):
        if p.annotation is sig.empty:
            if takes_ctx and index == 0:
                # should be the `context` argument, skip
                continue
            # TODO warn?
            annotation = Any
        else:
            # use the resolved hint rather than the raw (possibly string) annotation
            annotation = type_hints[name]

            if index == 0 and takes_ctx:
                if not _is_call_ctx(annotation):
                    errors.append('First parameter of tools that take context must be annotated with RunContext[...]')
                continue
            elif not takes_ctx and _is_call_ctx(annotation):
                errors.append('RunContext annotations can only be used with tools that take context')
                continue
            elif index != 0 and _is_call_ctx(annotation):
                errors.append('RunContext annotations can only be used as the first argument')
                continue

        field_name = p.name
        if p.kind == Parameter.VAR_KEYWORD:
            # `**kwargs` does not become a field; its schema is used for extra keys instead
            var_kwargs_schema = gen_schema.generate_schema(annotation)
        else:
            if p.kind == Parameter.VAR_POSITIONAL:
                # `*args: T` is exposed as a single field holding a list of `T`
                annotation = list[annotation]

            # FieldInfo.from_annotation expects a type, `annotation` is Any
            annotation = cast(type[Any], annotation)
            field_info = FieldInfo.from_annotation(annotation)
            if field_info.description is None:
                # fall back to the description extracted from the docstring, if any
                field_info.description = field_descriptions.get(field_name)

            fields[field_name] = td_schema = gen_schema._generate_td_field_schema(  # pyright: ignore[reportPrivateUsage]
                field_name,
                field_info,
                decorators,
                required=p.default is Parameter.empty,
            )
            # noinspection PyTypeChecker
            td_schema.setdefault('metadata', {})['is_model_like'] = is_model_like(annotation)

            if p.kind == Parameter.POSITIONAL_ONLY:
                positional_fields.append(field_name)
            elif p.kind == Parameter.VAR_POSITIONAL:
                var_positional_field = field_name

    if errors:
        from .exceptions import UserError

        error_details = '\n  '.join(errors)
        raise UserError(f'Error generating schema for {function.__qualname__}:\n  {error_details}')

    core_config = config_wrapper.core_config(None)
    # noinspection PyTypedDict
    core_config['extra_fields_behavior'] = 'allow' if var_kwargs_schema else 'forbid'

    schema, single_arg_name = _build_schema(fields, var_kwargs_schema, gen_schema, core_config)
    schema = gen_schema.clean_schema(schema)
    # noinspection PyUnresolvedReferences
    schema_validator = create_schema_validator(
        schema,
        function,
        function.__module__,
        function.__qualname__,
        'validate_call',
        core_config,
        config_wrapper.plugin_settings,
    )
    # PluggableSchemaValidator is api compatible with SchemaValidator
    schema_validator = cast(SchemaValidator, schema_validator)
    json_schema = GenerateJsonSchema().generate(schema)

    # workaround for https://github.com/pydantic/pydantic/issues/10785
    # if we build a custom TypeDict schema (matches when `single_arg_name is None`), we manually set
    # `additionalProperties` in the JSON Schema
    if single_arg_name is None:
        json_schema['additionalProperties'] = bool(var_kwargs_schema)
    elif not description:
        # if the tool description is not set, and we have a single parameter, take the description from that
        # and set it on the tool; keep the existing value when the parameter schema has none either, so the
        # declared `str` type of `FunctionSchema.description` is not silently replaced by `None`
        description = json_schema.pop('description', description)

    return FunctionSchema(
        description=description,
        validator=schema_validator,
        json_schema=check_object_json_schema(json_schema),
        single_arg_name=single_arg_name,
        positional_fields=positional_fields,
        var_positional_field=var_positional_field,
    )

172 

173 

def takes_ctx(function: Callable[..., Any]) -> bool:
    """Check if a function takes a `RunContext` first argument.

    Args:
        function: The function to check.

    Returns:
        `True` if the function takes a `RunContext` as first argument, `False` otherwise
        (including when the function has no parameters or its first parameter is unannotated).
    """
    sig = signature(function)
    first_param_name = next(iter(sig.parameters), None)
    if first_param_name is None:
        # no parameters at all, so there is nothing that could be a context argument
        return False

    type_hints = _typing_extra.get_function_type_hints(function)
    # An unannotated first parameter has no entry in the hints mapping; indexing it directly would
    # raise `KeyError`, so fall back to `sig.empty` and let the check below return `False`.
    annotation = type_hints.get(first_param_name, sig.empty)
    return annotation is not sig.empty and _is_call_ctx(annotation)

192 

193 

194def _build_schema( 

195 fields: dict[str, core_schema.TypedDictField], 

196 var_kwargs_schema: core_schema.CoreSchema | None, 

197 gen_schema: _generate_schema.GenerateSchema, 

198 core_config: core_schema.CoreConfig, 

199) -> tuple[core_schema.CoreSchema, str | None]: 

200 """Generate a typed dict schema for function parameters. 

201 

202 Args: 

203 fields: The fields to generate a typed dict schema for. 

204 var_kwargs_schema: The variable keyword arguments schema. 

205 gen_schema: The `GenerateSchema` instance. 

206 core_config: The core configuration. 

207 

208 Returns: 

209 tuple of (generated core schema, single arg name). 

210 """ 

211 if len(fields) == 1 and var_kwargs_schema is None: 

212 name = next(iter(fields)) 

213 td_field = fields[name] 

214 if td_field['metadata']['is_model_like']: # type: ignore 

215 return td_field['schema'], name 

216 

217 td_schema = core_schema.typed_dict_schema( 

218 fields, 

219 config=core_config, 

220 extras_schema=gen_schema.generate_schema(var_kwargs_schema) if var_kwargs_schema else None, 

221 ) 

222 return td_schema, None 

223 

224 

def _is_call_ctx(annotation: Any) -> bool:
    """Return whether `annotation` is `RunContext` itself or a parametrized `RunContext[...]` alias."""
    # imported inside the function rather than at module level (`.tools` is only imported under
    # `TYPE_CHECKING` at the top of this module)
    from .tools import RunContext

    if annotation is RunContext:
        return True
    if not _typing_extra.is_generic_alias(annotation):
        return False
    return get_origin(annotation) is RunContext