Skip to content

Commit

Permalink
[WebServer] Refactor routing to allow same path for websocket and web…
Browse files Browse the repository at this point in the history
… requests (#962)

* Switch to WS

* Refactor
  • Loading branch information
abhinavsingh authored Jan 11, 2022
1 parent 474cce1 commit a84abab
Showing 1 changed file with 60 additions and 56 deletions.
116 changes: 60 additions & 56 deletions proxy/http/server/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import mimetypes

from typing import List, Optional, Dict, Union, Any, Pattern
from typing import List, Optional, Dict, Tuple, Union, Any, Pattern

from ...common.constants import DEFAULT_STATIC_SERVER_DIR
from ...common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_WEB_SERVER
Expand All @@ -28,7 +28,7 @@
from ..websocket import WebsocketFrame, websocketOpcodes
from ..parser import HttpParser, httpParserTypes
from ..protocols import httpProtocols
from ..responses import NOT_FOUND_RESPONSE_PKT, NOT_IMPLEMENTED_RESPONSE_PKT, okResponse
from ..responses import NOT_FOUND_RESPONSE_PKT, okResponse

from .plugin import HttpWebServerBasePlugin
from .protocols import httpProtocolTypes
Expand Down Expand Up @@ -138,65 +138,28 @@ def read_and_build_static_file_response(path: str) -> memoryview:
except FileNotFoundError:
return NOT_FOUND_RESPONSE_PKT

def try_upgrade(self) -> bool:
if self.request.has_header(b'connection') and \
self.request.header(b'connection').lower() == b'upgrade':
if self.request.has_header(b'upgrade') and \
self.request.header(b'upgrade').lower() == b'websocket':
self.client.queue(
memoryview(
build_websocket_handshake_response(
WebsocketFrame.key_to_accept(
self.request.header(b'Sec-WebSocket-Key'),
),
),
def switch_to_websocket(self) -> None:
self.client.queue(
memoryview(
build_websocket_handshake_response(
WebsocketFrame.key_to_accept(
self.request.header(b'Sec-WebSocket-Key'),
),
)
self.switched_protocol = httpProtocolTypes.WEBSOCKET
else:
self.client.queue(NOT_IMPLEMENTED_RESPONSE_PKT)
return True
return False
),
),
)
self.switched_protocol = httpProtocolTypes.WEBSOCKET

def on_request_complete(self) -> Union[socket.socket, bool]:
path = self.request.path or b'/'
# Routing for Http(s) requests
protocol = httpProtocolTypes.HTTPS \
if self.encryption_enabled() else \
httpProtocolTypes.HTTP
for route in self.routes[protocol]:
if route.match(text_(path)):
self.route = self.routes[protocol][route]
assert self.route
self.route.handle_request(self.request)
if self.request.has_header(b'connection') and \
self.request.header(b'connection').lower() == b'close':
return True
return False
# If a websocket route exists for the path, try upgrade
for route in self.routes[httpProtocolTypes.WEBSOCKET]:
if route.match(text_(path)):
self.route = self.routes[httpProtocolTypes.WEBSOCKET][route]
# Connection upgrade
teardown = self.try_upgrade()
if teardown:
return True
# For upgraded connections, nothing more to do
if self.switched_protocol:
# Invoke plugin.on_websocket_open
assert self.route
self.route.on_websocket_open()
return False
break
# Try route
teardown = self._try_route(path)
if teardown:
return teardown
# No-route found, try static serving if enabled
if self.flags.enable_static_server:
path = text_(path).split('?', 1)[0]
self.client.queue(
self.read_and_build_static_file_response(
self.flags.static_server_dir + path,
),
)
return True
teardown = self._try_static_file(path)
if teardown:
return teardown
# Catch all unhandled web server requests, return 404
self.client.queue(NOT_FOUND_RESPONSE_PKT)
return True
Expand Down Expand Up @@ -305,3 +268,44 @@ def on_client_connection_close(self) -> None:

def access_log(self, context: Dict[str, Any]) -> None:
logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context))

@property
def _protocol(self) -> Tuple[bool, int]:
do_ws_upgrade = self.request.is_connection_upgrade and \
self.request.header(b'upgrade').lower() == b'websocket'
return do_ws_upgrade, httpProtocolTypes.WEBSOCKET \
if do_ws_upgrade \
else httpProtocolTypes.HTTPS \
if self.encryption_enabled() \
else httpProtocolTypes.HTTP

def _try_route(self, path: bytes) -> bool:
do_ws_upgrade, protocol = self._protocol
for route in self.routes[protocol]:
if route.match(text_(path)):
self.route = self.routes[protocol][route]
assert self.route
# Optionally, upgrade protocol
if do_ws_upgrade:
self.switch_to_websocket()
assert self.route
# Invoke plugin.on_websocket_open
self.route.on_websocket_open()
else:
# Invoke plugin.handle_request
self.route.handle_request(self.request)
if self.request.has_header(b'connection') and \
self.request.header(b'connection').lower() == b'close':
return True
return False

def _try_static_file(self, path: bytes) -> bool:
if self.flags.enable_static_server:
path = text_(path).split('?', 1)[0]
self.client.queue(
self.read_and_build_static_file_response(
self.flags.static_server_dir + path,
),
)
return True
return False

0 comments on commit a84abab

Please sign in to comment.