Coverage for tests/test_ws_router.py: 100%
157 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-12-04 08:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-12-04 08:29 +0000
1import functools 1abcdefg
3import pytest 1abcdefg
4from fastapi import ( 1abcdefg
5 APIRouter,
6 Depends,
7 FastAPI,
8 Header,
9 WebSocket,
10 WebSocketDisconnect,
11 status,
12)
13from fastapi.middleware import Middleware 1abcdefg
14from fastapi.testclient import TestClient 1abcdefg
16router = APIRouter() 1abcdefg
17prefix_router = APIRouter() 1abcdefg
18native_prefix_route = APIRouter(prefix="/native") 1abcdefg
19app = FastAPI() 1abcdefg
22@app.websocket_route("/") 1abcdefg
23async def index(websocket: WebSocket): 1abcdefg
24 await websocket.accept() 1JKLMNOP
25 await websocket.send_text("Hello, world!") 1JKLMNOP
26 await websocket.close() 1JKLMNOP
29@router.websocket_route("/router") 1abcdefg
30async def routerindex(websocket: WebSocket): 1abcdefg
31 await websocket.accept() 1QRSTUVW
32 await websocket.send_text("Hello, router!") 1QRSTUVW
33 await websocket.close() 1QRSTUVW
36@prefix_router.websocket_route("/") 1abcdefg
37async def routerprefixindex(websocket: WebSocket): 1abcdefg
38 await websocket.accept() 1XYZ0123
39 await websocket.send_text("Hello, router with prefix!") 1XYZ0123
40 await websocket.close() 1XYZ0123
43@router.websocket("/router2") 1abcdefg
44async def routerindex2(websocket: WebSocket): 1abcdefg
45 await websocket.accept() 1456789!
46 await websocket.send_text("Hello, router!") 1456789!
47 await websocket.close() 1456789!
50@router.websocket("/router/{pathparam:path}") 1abcdefg
51async def routerindexparams(websocket: WebSocket, pathparam: str, queryparam: str): 1abcdefg
52 await websocket.accept() 1CDEFGHI
53 await websocket.send_text(pathparam) 1CDEFGHI
54 await websocket.send_text(queryparam) 1CDEFGHI
55 await websocket.close() 1CDEFGHI
58async def ws_dependency(): 1abcdefg
59 return "Socket Dependency" 1#$%'()*
62@router.websocket("/router-ws-depends/") 1abcdefg
63async def router_ws_decorator_depends( 1abcdefg
64 websocket: WebSocket, data=Depends(ws_dependency)
65):
66 await websocket.accept() 1#+$,%-'.(/):*;
67 await websocket.send_text(data) 1#+$,%-'.(/):*;
68 await websocket.close() 1#+$,%-'.(/):*;
71@native_prefix_route.websocket("/") 1abcdefg
72async def router_native_prefix_ws(websocket: WebSocket): 1abcdefg
73 await websocket.accept() 1=?@[]^_
74 await websocket.send_text("Hello, router with native prefix!") 1=?@[]^_
75 await websocket.close() 1=?@[]^_
78async def ws_dependency_err(): 1abcdefg
79 raise NotImplementedError() 1ihjklmn
82@router.websocket("/depends-err/") 1abcdefg
83async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)): 1abcdefg
84 pass # pragma: no cover
87async def ws_dependency_validate(x_missing: str = Header()): 1abcdefg
88 pass # pragma: no cover
91@router.websocket("/depends-validate/") 1abcdefg
92async def router_ws_depends_validate( 1abcdefg
93 websocket: WebSocket, data=Depends(ws_dependency_validate)
94):
95 pass # pragma: no cover
98class CustomError(Exception): 1abcdefg
99 pass 1abcdefg
102@router.websocket("/custom_error/") 1abcdefg
103async def router_ws_custom_error(websocket: WebSocket): 1abcdefg
104 raise CustomError() 1wvxyzAB
107def make_app(app=None, **kwargs): 1abcdefg
108 app = app or FastAPI(**kwargs) 1awipbvhocxjqdykrezlsfAmtgBnu
109 app.include_router(router) 1awipbvhocxjqdykrezlsfAmtgBnu
110 app.include_router(prefix_router, prefix="/prefix") 1awipbvhocxjqdykrezlsfAmtgBnu
111 app.include_router(native_prefix_route) 1awipbvhocxjqdykrezlsfAmtgBnu
112 return app 1awipbvhocxjqdykrezlsfAmtgBnu
115app = make_app(app) 1abcdefg
118def test_app(): 1abcdefg
119 client = TestClient(app) 1JKLMNOP
120 with client.websocket_connect("/") as websocket: 1JKLMNOP
121 data = websocket.receive_text() 1JKLMNOP
122 assert data == "Hello, world!" 1JKLMNOP
125def test_router(): 1abcdefg
126 client = TestClient(app) 1QRSTUVW
127 with client.websocket_connect("/router") as websocket: 1QRSTUVW
128 data = websocket.receive_text() 1QRSTUVW
129 assert data == "Hello, router!" 1QRSTUVW
132def test_prefix_router(): 1abcdefg
133 client = TestClient(app) 1XYZ0123
134 with client.websocket_connect("/prefix/") as websocket: 1XYZ0123
135 data = websocket.receive_text() 1XYZ0123
136 assert data == "Hello, router with prefix!" 1XYZ0123
139def test_native_prefix_router(): 1abcdefg
140 client = TestClient(app) 1=?@[]^_
141 with client.websocket_connect("/native/") as websocket: 1=?@[]^_
142 data = websocket.receive_text() 1=?@[]^_
143 assert data == "Hello, router with native prefix!" 1=?@[]^_
146def test_router2(): 1abcdefg
147 client = TestClient(app) 1456789!
148 with client.websocket_connect("/router2") as websocket: 1456789!
149 data = websocket.receive_text() 1456789!
150 assert data == "Hello, router!" 1456789!
153def test_router_ws_depends(): 1abcdefg
154 client = TestClient(app) 1#$%'()*
155 with client.websocket_connect("/router-ws-depends/") as websocket: 1#$%'()*
156 assert websocket.receive_text() == "Socket Dependency" 1#$%'()*
159def test_router_ws_depends_with_override(): 1abcdefg
160 client = TestClient(app) 1+,-./:;
161 app.dependency_overrides[ws_dependency] = lambda: "Override" # noqa: E731 1+,-./:;
162 with client.websocket_connect("/router-ws-depends/") as websocket: 1+,-./:;
163 assert websocket.receive_text() == "Override" 1+,-./:;
166def test_router_with_params(): 1abcdefg
167 client = TestClient(app) 1CDEFGHI
168 with client.websocket_connect( 1CDEFGHI
169 "/router/path/to/file?queryparam=a_query_param"
170 ) as websocket:
171 data = websocket.receive_text() 1CDEFGHI
172 assert data == "path/to/file" 1CDEFGHI
173 data = websocket.receive_text() 1CDEFGHI
174 assert data == "a_query_param" 1CDEFGHI
177def test_wrong_uri(): 1abcdefg
178 """
179 Verify that a websocket connection to a non-existent endpoing returns in a shutdown
180 """
181 client = TestClient(app) 2{ ` | } ~ abbb
182 with pytest.raises(WebSocketDisconnect) as e: 2{ ` | } ~ abbb
183 with client.websocket_connect("/no-router/"): 2{ ` | } ~ abbb
184 pass # pragma: no cover 1`
185 assert e.value.code == status.WS_1000_NORMAL_CLOSURE 2{ ` | } ~ abbb
188def websocket_middleware(middleware_func): 1abcdefg
189 """
190 Helper to create a Starlette pure websocket middleware
191 """
193 def middleware_constructor(app): 1iphojqkrlsmtnu
194 @functools.wraps(app) 1iphojqkrlsmtnu
195 async def wrapped_app(scope, receive, send): 1iphojqkrlsmtnu
196 if scope["type"] != "websocket": 1iphojqkrlsmtnu
197 return await app(scope, receive, send) # pragma: no cover
199 async def call_next(): 1iphojqkrlsmtnu
200 return await app(scope, receive, send) 1iphojqkrlsmtnu
202 websocket = WebSocket(scope, receive=receive, send=send) 1iphojqkrlsmtnu
203 return await middleware_func(websocket, call_next) 1iphojqkrlsmtnu
205 return wrapped_app 1iphojqkrlsmtnu
207 return middleware_constructor 1iphojqkrlsmtnu
210def test_depend_validation(): 1abcdefg
211 """
212 Verify that a validation in a dependency invokes the correct exception handler
213 """
214 caught = [] 1poqrstu
216 @websocket_middleware 1poqrstu
217 async def catcher(websocket, call_next): 1poqrstu
218 try: 1poqrstu
219 return await call_next() 1poqrstu
220 except Exception as e: # pragma: no cover
221 caught.append(e)
222 raise
224 myapp = make_app(middleware=[Middleware(catcher)]) 1poqrstu
226 client = TestClient(myapp) 1poqrstu
227 with pytest.raises(WebSocketDisconnect) as e: 1poqrstu
228 with client.websocket_connect("/depends-validate/"): 1poqrstu
229 pass # pragma: no cover 1o
230 # the validation error does produce a close message
231 assert e.value.code == status.WS_1008_POLICY_VIOLATION 1poqrstu
232 # and no error is leaked
233 assert caught == [] 1poqrstu
236def test_depend_err_middleware(): 1abcdefg
237 """
238 Verify that it is possible to write custom WebSocket middleware to catch errors
239 """
241 @websocket_middleware 1ihjklmn
242 async def errorhandler(websocket: WebSocket, call_next): 1ihjklmn
243 try: 1ihjklmn
244 return await call_next() 1ihjklmn
245 except Exception as e: 1ihjklmn
246 await websocket.close(code=status.WS_1006_ABNORMAL_CLOSURE, reason=repr(e)) 1ihjklmn
248 myapp = make_app(middleware=[Middleware(errorhandler)]) 1ihjklmn
249 client = TestClient(myapp) 1ihjklmn
250 with pytest.raises(WebSocketDisconnect) as e: 1ihjklmn
251 with client.websocket_connect("/depends-err/"): 1ihjklmn
252 pass # pragma: no cover 1h
253 assert e.value.code == status.WS_1006_ABNORMAL_CLOSURE 1ihjklmn
254 assert "NotImplementedError" in e.value.reason 1ihjklmn
257def test_depend_err_handler(): 1abcdefg
258 """
259 Verify that it is possible to write custom WebSocket middleware to catch errors
260 """
262 async def custom_handler(websocket: WebSocket, exc: CustomError) -> None: 1wvxyzAB
263 await websocket.close(1002, "foo") 1wvxyzAB
265 myapp = make_app(exception_handlers={CustomError: custom_handler}) 1wvxyzAB
266 client = TestClient(myapp) 1wvxyzAB
267 with pytest.raises(WebSocketDisconnect) as e: 1wvxyzAB
268 with client.websocket_connect("/custom_error/"): 1wvxyzAB
269 pass # pragma: no cover 1v
270 assert e.value.code == 1002 1wvxyzAB
271 assert "foo" in e.value.reason 1wvxyzAB