Coverage for tests/conftest.py: 98.98%
96 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-01-25 16:43 +0000
1from __future__ import annotations as _annotations
3import asyncio
4import importlib.util
5import os
6import re
7import secrets
8import sys
9from collections.abc import AsyncIterator, Iterator
10from contextlib import contextmanager
11from datetime import datetime
12from pathlib import Path
13from types import ModuleType
14from typing import TYPE_CHECKING, Any, Callable
16import httpx
17import pytest
18from _pytest.assertion.rewrite import AssertionRewritingHook
19from typing_extensions import TypeAlias
21import pydantic_ai.models
# Public names exposed by this conftest module.
__all__ = (
    'IsNow',
    'IsFloat',
    'TestEnv',
    'ClientWithHandler',
    'try_import',
)

# Switch the `ALLOW_MODEL_REQUESTS` flag off at import time; the
# `allow_model_requests` fixture overrides it for tests that opt in.
# NOTE(review): presumably this blocks real model requests during tests — confirm in pydantic_ai.models.
pydantic_ai.models.ALLOW_MODEL_REQUESTS = False
if TYPE_CHECKING:
    # Type checkers only see these stubs, which are typed as returning the
    # value being compared against.
    def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
    def IsFloat(*args: Any, **kwargs: Any) -> float: ...
else:
    from dirty_equals import IsFloat, IsNow as _IsNow

    def IsNow(*args: Any, **kwargs: Any):
        """Call `dirty_equals.IsNow` with `delta` defaulting to 10.

        An explicitly passed `delta` is left untouched; all other arguments
        are forwarded unchanged.
        """
        # A larger default `delta` reduces test flakiness on overburdened machines.
        kwargs.setdefault('delta', 10)
        return _IsNow(*args, **kwargs)
try:
    from logfire.testing import CaptureLogfire
except ImportError:
    # logfire is not importable here, so no fixture is registered.
    pass
else:

    @pytest.fixture(autouse=True)
    def logfire_disable(capfire: CaptureLogfire):
        """Request the `capfire` fixture for every test; nothing else is done here.

        NOTE(review): presumably `capfire` captures logfire output instead of
        sending it anywhere — confirm against `logfire.testing`.
        """
class TestEnv:
    """Change environment variables and restore their original state later.

    Every variable touched through `set` or `remove` has its value from
    *before the first change* recorded, so `reset` always restores the
    original state, even when the same variable is modified several times.
    """

    # Stop pytest from collecting this class even though its name starts with `Test`.
    __test__ = False

    def __init__(self):
        # Maps variable name -> value before the first change (`None` means it was unset).
        self.envars: dict[str, str | None] = {}

    def set(self, name: str, value: str) -> None:
        """Set environment variable `name` to `value`."""
        self._remember(name)
        os.environ[name] = value

    def remove(self, name: str) -> None:
        """Remove environment variable `name`; a no-op on the environment if it is not set."""
        self._remember(name)
        os.environ.pop(name, None)

    def reset(self) -> None:  # pragma: no cover
        """Restore every touched variable to the value it had before its first change."""
        for name, value in self.envars.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value

    def _remember(self, name: str) -> None:
        """Record the current value of `name` unless one has already been recorded."""
        # Only the first recording counts: a later call would capture a value
        # this object set itself, and `reset` would then restore the wrong thing.
        if name not in self.envars:
            self.envars[name] = os.getenv(name)
@pytest.fixture
def env() -> Iterator[TestEnv]:
    """Provide a fresh `TestEnv` and undo its environment changes afterwards."""
    scratch_env = TestEnv()
    yield scratch_env
    # Teardown: put back every variable the test changed through `scratch_env`.
    scratch_env.reset()
@pytest.fixture
def anyio_backend():
    """Return the backend name `asyncio`.

    NOTE(review): presumably read by the anyio pytest plugin to pick its backend — confirm.
    """
    backend_name = 'asyncio'
    return backend_name
@pytest.fixture
def allow_model_requests():
    """Run the test inside `override_allow_model_requests(True)`.

    The override is entered before the test and exited after it.
    """
    allow_requests = pydantic_ai.models.override_allow_model_requests(True)
    with allow_requests:
        yield
@pytest.fixture
async def client_with_handler() -> AsyncIterator[ClientWithHandler]:
    """Yield a factory that builds an `httpx.AsyncClient` backed by a mock transport.

    The factory takes a request handler and returns a client whose requests
    are all routed to that handler. It may be called at most once per test;
    the client it created, if any, is closed on teardown.
    """
    active_client: httpx.AsyncClient | None = None

    def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.AsyncClient:
        nonlocal active_client
        assert active_client is None, 'client_with_handler can only be called once'
        # Mount the mock transport on every URL pattern so all requests reach `handler`.
        mock_transport = httpx.MockTransport(handler)
        active_client = httpx.AsyncClient(mounts={'all://': mock_transport})
        return active_client

    try:
        yield create_client
    finally:
        # Nothing to close when the test never called the factory.
        if active_client:  # pragma: no cover
            await active_client.aclose()
# Type of the factory yielded by the `client_with_handler` fixture:
# it takes a request handler and returns an `httpx.AsyncClient`.
ClientWithHandler: TypeAlias = Callable[
    [Callable[[httpx.Request], httpx.Response]],
    httpx.AsyncClient,
]
114# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false
@pytest.fixture
def create_module(tmp_path: Path, request: pytest.FixtureRequest) -> Callable[[str], Any]:
    """Taken from `pydantic/tests/conftest.py`, create module object, execute and return it.

    The fixture value is the `run` function defined below.
    """

    def run(
        source_code: str,
        rewrite_assertions: bool = True,
        module_name_prefix: str | None = None,
    ) -> ModuleType:
        """Write `source_code` to a file under `tmp_path`, import it as a module and return it.

        Args:
            source_code: Python source code of the module
            rewrite_assertions: whether to rewrite assertions in module or not
            module_name_prefix: string prefix to use in the name of the module, does not affect the name of the file.

        Returns:
            The executed module, which is also registered in `sys.modules` under its module name.
        """
        # Max path length in Windows is 260. Leaving some buffer here
        max_name_len = 240 - len(str(tmp_path))
        # Windows does not allow these characters in paths. Linux bans slashes only.
        sanitized_name = re.sub('[' + re.escape('<>:"/\\|?*') + ']', '-', request.node.name)[:max_name_len]
        # The random suffix keeps modules created within the same test from colliding.
        module_name = f'{sanitized_name}_{secrets.token_hex(5)}'
        path = tmp_path / f'{module_name}.py'
        # Python reads source files as UTF-8 by default, so write UTF-8 explicitly
        # instead of relying on the platform's locale encoding.
        path.write_text(source_code, encoding='utf-8')
        filename = str(path)

        if module_name_prefix:  # pragma: no cover
            module_name = module_name_prefix + module_name

        if rewrite_assertions:
            loader = AssertionRewritingHook(config=request.config)
            loader.mark_rewrite(module_name)
        else:  # pragma: no cover
            # `None` lets importlib choose its default loader for the file.
            loader = None

        spec = importlib.util.spec_from_file_location(module_name, filename, loader=loader)
        sys.modules[module_name] = module = importlib.util.module_from_spec(spec)  # pyright: ignore[reportArgumentType]
        spec.loader.exec_module(module)  # pyright: ignore[reportOptionalMemberAccess]
        return module

    return run
@contextmanager
def try_import() -> Iterator[Callable[[], bool]]:
    """Context manager that swallows `ImportError` and reports whether the body succeeded.

    The value bound by `with ... as` is a zero-argument callable. It returns
    `True` only after the `with` body finished without raising, and `False`
    otherwise. Exceptions other than `ImportError` propagate to the caller.
    """
    outcome = {'ok': False}

    def check_import() -> bool:
        return outcome['ok']

    try:
        yield check_import
    except ImportError:
        # Swallow the failed import; `check_import()` keeps returning False.
        return
    outcome['ok'] = True
@pytest.fixture(autouse=True)
def set_event_loop() -> Iterator[None]:
    """Install a newly created asyncio event loop as the current one, and close it on teardown."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield
    # Teardown: close the loop created for this fixture instance.
    loop.close()