Coverage for tests/test_ws_router.py: 100%

157 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-12-04 08:29 +0000

1import functools 1abcdefg

2 

3import pytest 1abcdefg

4from fastapi import ( 1abcdefg

5 APIRouter, 

6 Depends, 

7 FastAPI, 

8 Header, 

9 WebSocket, 

10 WebSocketDisconnect, 

11 status, 

12) 

13from fastapi.middleware import Middleware 1abcdefg

14from fastapi.testclient import TestClient 1abcdefg

15 

# The application under test; make_app() later mounts the routers below onto it.
app = FastAPI()
# Plain router whose endpoints are mounted without a prefix.
router = APIRouter()
# Router that receives its "/prefix" prefix at include time.
prefix_router = APIRouter()
# Router that declares its own "/native" prefix up front.
native_prefix_route = APIRouter(prefix="/native")

20 

21 

@app.websocket_route("/")
async def index(websocket: WebSocket):
    """Greet the client once on the app's root socket, then close."""
    greeting = "Hello, world!"
    await websocket.accept()
    await websocket.send_text(greeting)
    await websocket.close()

27 

28 

@router.websocket_route("/router")
async def routerindex(websocket: WebSocket):
    """Greet the client via a ``websocket_route`` declared on a router."""
    greeting = "Hello, router!"
    await websocket.accept()
    await websocket.send_text(greeting)
    await websocket.close()

34 

35 

@prefix_router.websocket_route("/")
async def routerprefixindex(websocket: WebSocket):
    """Greet the client from the root of a router that is included with a prefix."""
    greeting = "Hello, router with prefix!"
    await websocket.accept()
    await websocket.send_text(greeting)
    await websocket.close()

41 

42 

@router.websocket("/router2")
async def routerindex2(websocket: WebSocket):
    """Greet the client via the ``@router.websocket`` decorator form."""
    greeting = "Hello, router!"
    await websocket.accept()
    await websocket.send_text(greeting)
    await websocket.close()

48 

49 

@router.websocket("/router/{pathparam:path}")
async def routerindexparams(websocket: WebSocket, pathparam: str, queryparam: str):
    """Echo the path parameter, then the query parameter, as two text frames."""
    await websocket.accept()
    for value in (pathparam, queryparam):
        await websocket.send_text(value)
    await websocket.close()

56 

57 

async def ws_dependency():
    """Resolve to a fixed marker string so tests can see the dependency ran."""
    marker = "Socket Dependency"
    return marker

60 

61 

@router.websocket("/router-ws-depends/")
async def router_ws_decorator_depends(
    websocket: WebSocket, data=Depends(ws_dependency)
):
    """Send whatever ``ws_dependency`` (or its override) resolved to, then close."""
    await websocket.accept()
    await websocket.send_text(data=data)
    await websocket.close()

69 

70 

@native_prefix_route.websocket("/")
async def router_native_prefix_ws(websocket: WebSocket):
    """Greet the client from the root of a router that declares its own prefix."""
    greeting = "Hello, router with native prefix!"
    await websocket.accept()
    await websocket.send_text(greeting)
    await websocket.close()

76 

77 

async def ws_dependency_err():
    """Always fail, so tests can check how dependency errors are surfaced."""
    raise NotImplementedError

80 

81 

@router.websocket("/depends-err/")
async def router_ws_depends_err(websocket: WebSocket, data=Depends(ws_dependency_err)):
    """Never reached: resolving ``ws_dependency_err`` raises first."""
    ...  # pragma: no cover

85 

86 

async def ws_dependency_validate(x_missing: str = Header()):
    """Declare a required header parameter that the tests never send."""
    ...  # pragma: no cover

89 

90 

@router.websocket("/depends-validate/")
async def router_ws_depends_validate(
    websocket: WebSocket, data=Depends(ws_dependency_validate)
):
    """Never reached: validating ``ws_dependency_validate``'s header fails first."""
    ...  # pragma: no cover

96 

97 

class CustomError(Exception):
    """Test-only exception routed to an app-level WebSocket exception handler."""

100 

101 

@router.websocket("/custom_error/")
async def router_ws_custom_error(websocket: WebSocket):
    """Fail immediately with ``CustomError`` so a registered handler can react."""
    raise CustomError

105 

106 

def make_app(app=None, **kwargs):
    """Mount this module's routers onto an application and return it.

    Args:
        app: Application to extend in place. When falsy (the default ``None``),
            a fresh ``FastAPI(**kwargs)`` is created instead.
        **kwargs: Passed to ``FastAPI`` only when a new app is created; they are
            ignored when ``app`` is supplied.

    Returns:
        The application with ``router``, ``prefix_router`` (under "/prefix") and
        ``native_prefix_route`` included, in that order.
    """
    if not app:
        app = FastAPI(**kwargs)
    # Inclusion order is kept stable; "" is include_router's default prefix.
    mounts = (
        (router, ""),
        (prefix_router, "/prefix"),
        (native_prefix_route, ""),
    )
    for sub_router, mount_prefix in mounts:
        app.include_router(sub_router, prefix=mount_prefix)
    return app

113 

114 

# Mount the shared routers onto the module-level app, which already serves "/".
app = make_app(app=app)

116 

117 

def test_app():
    """The ``websocket_route`` registered directly on the app greets the client."""
    test_client = TestClient(app)
    with test_client.websocket_connect("/") as ws:
        assert ws.receive_text() == "Hello, world!"

123 

124 

def test_router():
    """A ``websocket_route`` on an included router is reachable at its path."""
    test_client = TestClient(app)
    with test_client.websocket_connect("/router") as ws:
        assert ws.receive_text() == "Hello, router!"

130 

131 

def test_prefix_router():
    """A router included with ``prefix="/prefix"`` serves its root under that prefix."""
    test_client = TestClient(app)
    with test_client.websocket_connect("/prefix/") as ws:
        assert ws.receive_text() == "Hello, router with prefix!"

137 

138 

def test_native_prefix_router():
    """A router built with ``APIRouter(prefix="/native")`` serves its root there."""
    test_client = TestClient(app)
    with test_client.websocket_connect("/native/") as ws:
        assert ws.receive_text() == "Hello, router with native prefix!"

144 

145 

def test_router2():
    """The ``@router.websocket`` decorator form registers a reachable endpoint."""
    test_client = TestClient(app)
    with test_client.websocket_connect("/router2") as ws:
        assert ws.receive_text() == "Hello, router!"

151 

152 

def test_router_ws_depends():
    """A ``Depends`` parameter on a WebSocket endpoint receives the dependency's value."""
    test_client = TestClient(app)
    with test_client.websocket_connect("/router-ws-depends/") as ws:
        received = ws.receive_text()
        assert received == "Socket Dependency"

157 

158 

def test_router_ws_depends_with_override():
    """Verify that ``dependency_overrides`` also applies to WebSocket dependencies.

    The override is installed on the module-level ``app`` that every test in
    this file shares. It is removed again in ``finally`` so it cannot leak into
    tests that run afterwards, such as ``test_router_ws_depends`` under a
    reordered or repeated run.
    """

    def override_ws_dependency():
        return "Override"

    client = TestClient(app)
    app.dependency_overrides[ws_dependency] = override_ws_dependency
    try:
        with client.websocket_connect("/router-ws-depends/") as websocket:
            assert websocket.receive_text() == "Override"
    finally:
        app.dependency_overrides.pop(ws_dependency, None)

164 

165 

def test_router_with_params():
    """Path and query parameters are injected into a router WebSocket endpoint."""
    test_client = TestClient(app)
    url = "/router/path/to/file?queryparam=a_query_param"
    with test_client.websocket_connect(url) as ws:
        # The endpoint sends the path parameter first, then the query parameter.
        received = [ws.receive_text() for _ in range(2)]
        assert received == ["path/to/file", "a_query_param"]

175 

176 

def test_wrong_uri():
    """
    Verify that a websocket connection to a non-existent endpoint results in a
    shutdown: the server closes the socket with a normal-closure code.
    """
    client = TestClient(app)
    with pytest.raises(WebSocketDisconnect) as e:
        with client.websocket_connect("/no-router/"):
            pass  # pragma: no cover
    assert e.value.code == status.WS_1000_NORMAL_CLOSURE

186 

187 

def websocket_middleware(middleware_func):
    """
    Adapt ``middleware_func(websocket, call_next)`` into a pure ASGI middleware
    factory suitable for Starlette's ``Middleware(...)``.

    Non-websocket scopes bypass ``middleware_func`` and go straight to the wrapped
    app. For websocket scopes, a ``WebSocket`` is built from the scope, and
    ``call_next()`` runs the wrapped app with the original scope/receive/send.
    """

    # The parameter stays named ``app``: Starlette may pass it by keyword.
    def build(app):
        @functools.wraps(app)
        async def asgi_entry(scope, receive, send):
            async def call_next():
                return await app(scope, receive, send)

            if scope["type"] == "websocket":
                ws = WebSocket(scope, receive=receive, send=send)
                return await middleware_func(ws, call_next)
            return await call_next()  # pragma: no cover

        return asgi_entry

    return build

208 

209 

def test_depend_validation():
    """
    A validation failure inside a dependency must close the socket with a
    policy-violation code, and must not surface as an exception to middleware.
    """
    leaked = []

    @websocket_middleware
    async def catcher(websocket, call_next):
        try:
            return await call_next()
        except Exception as exc:  # pragma: no cover
            leaked.append(exc)
            raise

    validating_app = make_app(middleware=[Middleware(catcher)])
    test_client = TestClient(validating_app)

    with pytest.raises(WebSocketDisconnect) as disconnect:
        with test_client.websocket_connect("/depends-validate/"):
            pass  # pragma: no cover
    # The validation error is reported to the client as a close frame...
    assert disconnect.value.code == status.WS_1008_POLICY_VIOLATION
    # ...and nothing propagates out to the middleware.
    assert leaked == []

234 

235 

def test_depend_err_middleware():
    """
    Verify that it is possible to write custom WebSocket middleware to catch errors.

    The handler closes with 1011 (internal error) rather than 1006. RFC 6455
    reserves 1006 to mean "closed abnormally" and forbids endpoints from sending
    it in a close frame, so real servers would reject it.
    """

    @websocket_middleware
    async def errorhandler(websocket: WebSocket, call_next):
        try:
            return await call_next()
        except Exception as e:
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=repr(e))

    myapp = make_app(middleware=[Middleware(errorhandler)])
    client = TestClient(myapp)
    with pytest.raises(WebSocketDisconnect) as e:
        with client.websocket_connect("/depends-err/"):
            pass  # pragma: no cover
    assert e.value.code == status.WS_1011_INTERNAL_ERROR
    # The middleware forwards repr() of the dependency's exception as the reason.
    assert "NotImplementedError" in e.value.reason

255 

256 

def test_depend_err_handler():
    """
    Verify that an app-level exception handler registered for a custom exception
    type can close the WebSocket with its own code and reason.
    """

    async def custom_handler(websocket: WebSocket, exc: CustomError) -> None:
        await websocket.close(code=status.WS_1002_PROTOCOL_ERROR, reason="foo")

    myapp = make_app(exception_handlers={CustomError: custom_handler})
    client = TestClient(myapp)
    with pytest.raises(WebSocketDisconnect) as e:
        with client.websocket_connect("/custom_error/"):
            pass  # pragma: no cover
    assert e.value.code == status.WS_1002_PROTOCOL_ERROR
    assert "foo" in e.value.reason