diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py index fafc89e58a..5892c57171 100644 --- a/proxy/http/server/web.py +++ b/proxy/http/server/web.py @@ -14,7 +14,7 @@ import logging import mimetypes -from typing import List, Optional, Dict, Union, Any, Pattern +from typing import List, Optional, Dict, Tuple, Union, Any, Pattern from ...common.constants import DEFAULT_STATIC_SERVER_DIR from ...common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_WEB_SERVER @@ -28,7 +28,7 @@ from ..websocket import WebsocketFrame, websocketOpcodes from ..parser import HttpParser, httpParserTypes from ..protocols import httpProtocols -from ..responses import NOT_FOUND_RESPONSE_PKT, NOT_IMPLEMENTED_RESPONSE_PKT, okResponse +from ..responses import NOT_FOUND_RESPONSE_PKT, okResponse from .plugin import HttpWebServerBasePlugin from .protocols import httpProtocolTypes @@ -138,65 +138,28 @@ def read_and_build_static_file_response(path: str) -> memoryview: except FileNotFoundError: return NOT_FOUND_RESPONSE_PKT - def try_upgrade(self) -> bool: - if self.request.has_header(b'connection') and \ - self.request.header(b'connection').lower() == b'upgrade': - if self.request.has_header(b'upgrade') and \ - self.request.header(b'upgrade').lower() == b'websocket': - self.client.queue( - memoryview( - build_websocket_handshake_response( - WebsocketFrame.key_to_accept( - self.request.header(b'Sec-WebSocket-Key'), - ), - ), + def switch_to_websocket(self) -> None: + self.client.queue( + memoryview( + build_websocket_handshake_response( + WebsocketFrame.key_to_accept( + self.request.header(b'Sec-WebSocket-Key'), ), - ) - self.switched_protocol = httpProtocolTypes.WEBSOCKET - else: - self.client.queue(NOT_IMPLEMENTED_RESPONSE_PKT) - return True - return False + ), + ), + ) + self.switched_protocol = httpProtocolTypes.WEBSOCKET def on_request_complete(self) -> Union[socket.socket, bool]: path = self.request.path or b'/' - # Routing for Http(s) requests - protocol = httpProtocolTypes.HTTPS \ - if self.encryption_enabled() else \ - httpProtocolTypes.HTTP - for route in self.routes[protocol]: - if route.match(text_(path)): - self.route = self.routes[protocol][route] - assert self.route - self.route.handle_request(self.request) - if self.request.has_header(b'connection') and \ - self.request.header(b'connection').lower() == b'close': - return True - return False - # If a websocket route exists for the path, try upgrade - for route in self.routes[httpProtocolTypes.WEBSOCKET]: - if route.match(text_(path)): - self.route = self.routes[httpProtocolTypes.WEBSOCKET][route] - # Connection upgrade - teardown = self.try_upgrade() - if teardown: - return True - # For upgraded connections, nothing more to do - if self.switched_protocol: - # Invoke plugin.on_websocket_open - assert self.route - self.route.on_websocket_open() - return False - break + # Try route + teardown = self._try_route(path) + if teardown: + return teardown # No-route found, try static serving if enabled - if self.flags.enable_static_server: - path = text_(path).split('?', 1)[0] - self.client.queue( - self.read_and_build_static_file_response( - self.flags.static_server_dir + path, - ), - ) - return True + teardown = self._try_static_file(path) + if teardown: + return teardown # Catch all unhandled web server requests, return 404 self.client.queue(NOT_FOUND_RESPONSE_PKT) return True @@ -305,3 +268,44 @@ def on_client_connection_close(self) -> None: def access_log(self, context: Dict[str, Any]) -> None: logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context)) + + @property + def _protocol(self) -> Tuple[bool, int]: + do_ws_upgrade = self.request.is_connection_upgrade and \ + self.request.header(b'upgrade').lower() == b'websocket' + return do_ws_upgrade, httpProtocolTypes.WEBSOCKET \ + if do_ws_upgrade \ + else httpProtocolTypes.HTTPS \ + if self.encryption_enabled() \ + else httpProtocolTypes.HTTP + + def _try_route(self, path: bytes) -> bool: + do_ws_upgrade, protocol = self._protocol + for route in self.routes[protocol]: + if route.match(text_(path)): + self.route = self.routes[protocol][route] + assert self.route + # Optionally, upgrade protocol + if do_ws_upgrade: + self.switch_to_websocket() + assert self.route + # Invoke plugin.on_websocket_open + self.route.on_websocket_open() + else: + # Invoke plugin.handle_request + self.route.handle_request(self.request) + if self.request.has_header(b'connection') and \ + self.request.header(b'connection').lower() == b'close': + return True + return False + + def _try_static_file(self, path: bytes) -> bool: + if self.flags.enable_static_server: + path = text_(path).split('?', 1)[0] + self.client.queue( + self.read_and_build_static_file_response( + self.flags.static_server_dir + path, + ), + ) + return True + return False