Coverage for tests/conftest.py: 98.98%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-01-25 16:43 +0000

1from __future__ import annotations as _annotations 

2 

3import asyncio 

4import importlib.util 

5import os 

6import re 

7import secrets 

8import sys 

9from collections.abc import AsyncIterator, Iterator 

10from contextlib import contextmanager 

11from datetime import datetime 

12from pathlib import Path 

13from types import ModuleType 

14from typing import TYPE_CHECKING, Any, Callable 

15 

16import httpx 

17import pytest 

18from _pytest.assertion.rewrite import AssertionRewritingHook 

19from typing_extensions import TypeAlias 

20 

21import pydantic_ai.models 

22 

# Public helpers of this conftest module.
__all__ = (
    'IsNow',
    'IsFloat',
    'TestEnv',
    'ClientWithHandler',
    'try_import',
)


# Switch off `ALLOW_MODEL_REQUESTS` as soon as this conftest is imported; a test can turn it
# back on with the `allow_model_requests` fixture defined further down.
pydantic_ai.models.ALLOW_MODEL_REQUESTS = False

27 

if TYPE_CHECKING:
    # Static-only signatures: to a type checker these helpers look like the values they
    # are compared against (`datetime` / `float`).

    def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
    def IsFloat(*args: Any, **kwargs: Any) -> float: ...

else:
    from dirty_equals import IsFloat, IsNow as _IsNow

    def IsNow(*args: Any, **kwargs: Any):
        """Build a `dirty_equals.IsNow` matcher whose `delta` defaults to 10.

        A caller-supplied `delta` is passed through unchanged. The larger default makes
        tests less flaky on overburdened machines.
        """
        kwargs.setdefault('delta', 10)
        return _IsNow(*args, **kwargs)

40 

41 

try:
    from logfire.testing import CaptureLogfire
except ImportError:
    # logfire is optional; without it no fixture is registered.
    pass
else:

    @pytest.fixture(autouse=True)
    def logfire_disable(capfire: CaptureLogfire):
        """Request logfire's `capfire` fixture for every test.

        The body is intentionally empty: only the dependency on `capfire` matters.
        Presumably this keeps logfire output captured locally during tests
        (NOTE(review): confirm against the logfire testing docs).
        """

51 

52 

class TestEnv:
    """Record environment-variable changes made during a test and undo them later.

    For every variable touched through `set` or `remove`, the value it had before the
    *first* such call is remembered. `reset` restores exactly that original state.
    """

    # Keep pytest from collecting this helper as a test class (its name starts with `Test`).
    __test__ = False

    def __init__(self):
        # name -> value before this instance first touched it (`None` means it was unset).
        self.envars: dict[str, str | None] = {}

    def _remember(self, name: str) -> None:
        """Save the current value of `name` unless an original is already saved."""
        # Only the first observation is the true original. Re-saving on later calls would
        # capture values written by this helper, and `reset` would then restore those.
        if name not in self.envars:
            self.envars[name] = os.environ.get(name)

    def set(self, name: str, value: str) -> None:
        """Set environment variable `name` to `value`, remembering its original value."""
        self._remember(name)
        os.environ[name] = value

    def remove(self, name: str) -> None:
        """Unset environment variable `name` (no error if absent), remembering its original value."""
        self._remember(name)
        os.environ.pop(name, None)

    def reset(self) -> None:  # pragma: no cover
        """Restore every touched variable to its original value and forget the saved state."""
        for name, value in self.envars.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        # A later `set`/`remove` starts from the restored state, so a second `reset` is a no-op.
        self.envars.clear()

72 

73 

@pytest.fixture
def env() -> Iterator[TestEnv]:
    """Provide a `TestEnv` whose environment changes are reverted after the test."""
    tracker = TestEnv()
    yield tracker
    # Fixture teardown: put back every variable the test changed through `tracker`.
    tracker.reset()

81 

82 

@pytest.fixture
def anyio_backend():
    """Name of the async backend to run anyio-based tests on (always `'asyncio'`)."""
    backend_name = 'asyncio'
    return backend_name

86 

87 

@pytest.fixture
def allow_model_requests():
    """Allow model requests for the duration of the requesting test.

    Enters `pydantic_ai.models.override_allow_model_requests(True)` around the test.
    Presumably the context manager restores the previous setting on exit
    (NOTE(review): confirm in pydantic_ai.models).
    """
    override = pydantic_ai.models.override_allow_model_requests(True)
    with override:
        yield

@pytest.fixture
async def client_with_handler() -> AsyncIterator[ClientWithHandler]:
    """Yield a factory that builds an `httpx.AsyncClient` backed by a mock request handler.

    The factory takes a `handler(request) -> response` callable and may be called at most
    once per test. The client it created, if any, is closed when the fixture is torn down.
    """
    client: httpx.AsyncClient | None = None

    def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.AsyncClient:
        nonlocal client
        assert client is None, 'client_with_handler can only be called once'
        # Mount the mock transport for every URL so all requests go to `handler`.
        client = httpx.AsyncClient(mounts={'all://': httpx.MockTransport(handler)})
        return client

    try:
        yield create_client
    finally:
        # Explicit `None` check rather than relying on the client object's truthiness:
        # close only a client that the test actually created.
        if client is not None:  # pragma: no cover
            await client.aclose()

109 

110 

# Type of the `client_with_handler` fixture value: a factory that takes a request handler
# and returns an `httpx.AsyncClient` whose requests are served by that handler.
ClientWithHandler: TypeAlias = Callable[
    [Callable[[httpx.Request], httpx.Response]],
    httpx.AsyncClient,
]

112 

113 

114# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false 

@pytest.fixture
def create_module(tmp_path: Path, request: pytest.FixtureRequest) -> Callable[[str], Any]:
    """Taken from `pydantic/tests/conftest.py`, create module object, execute and return it."""

    def run(
        source_code: str,
        rewrite_assertions: bool = True,
        module_name_prefix: str | None = None,
    ) -> ModuleType:
        """Write `source_code` to a temporary `.py` file, import it as a new module and return it.

        Only a source-code string is accepted; passing a function is not supported.

        Args:
            source_code: Python source code of the module.
            rewrite_assertions: whether to load the module through pytest's
                `AssertionRewritingHook` so its `assert` statements are rewritten.
            module_name_prefix: string prefix to use in the name of the module; does not
                affect the name of the file.

        Returns:
            The executed module, registered in `sys.modules` under its module name.

        Raises:
            Whatever executing `source_code` raises. In that case the module is removed
            from `sys.modules` again.
        """
        # Max path length in Windows is 260. Leaving some buffer here. Clamp at 0: a
        # negative slice bound would trim from the end instead of capping the length.
        max_name_len = max(240 - len(str(tmp_path)), 0)
        # Windows does not allow these characters in paths. Linux bans slashes only.
        sanitized_name = re.sub('[' + re.escape('<>:"/\\|?*') + ']', '-', request.node.name)[:max_name_len]
        # Random suffix keeps names unique when one test creates several modules.
        module_name = f'{sanitized_name}_{secrets.token_hex(5)}'
        path = tmp_path / f'{module_name}.py'
        # Python source files are decoded as UTF-8 by default, so write them that way
        # rather than in the platform's locale encoding.
        path.write_text(source_code, encoding='utf-8')
        filename = str(path)

        if module_name_prefix:  # pragma: no cover
            module_name = module_name_prefix + module_name

        if rewrite_assertions:
            loader = AssertionRewritingHook(config=request.config)
            loader.mark_rewrite(module_name)
        else:  # pragma: no cover
            loader = None

        spec = importlib.util.spec_from_file_location(module_name, filename, loader=loader)
        assert spec is not None and spec.loader is not None, f'cannot build an import spec for {filename}'
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as the import system does, so code in the module can
        # find itself in `sys.modules`.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Mirror the import system: never leave a half-initialised module behind.
            sys.modules.pop(module_name, None)
            raise
        return module

    return run

159 

160 

@contextmanager
def try_import() -> Iterator[Callable[[], bool]]:
    """Guard optional imports placed inside a `with` block.

    Yields a zero-argument callable that returns True only if the `with` body finished
    without raising. An `ImportError` from the body is swallowed (the callable then
    returns False); any other exception propagates unchanged.
    """
    succeeded = False

    def imports_ok() -> bool:
        # Looks the flag up at call time, so it reports the final outcome of the block.
        return succeeded

    try:
        yield imports_ok
    except ImportError:
        return
    succeeded = True

174 

175 

@pytest.fixture(autouse=True)
def set_event_loop() -> Iterator[None]:
    """Give every test a fresh asyncio event loop, installed as the current loop.

    The loop is closed after the test. Note that it stays registered as the current loop
    until the next test installs its own.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield
    loop.close()