Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use core loop for reverse proxy async IO operations #675

Merged
merged 5 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions proxy/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
DEFAULT_TIMEOUT = 10
DEFAULT_VERSION = False
DEFAULT_HTTP_PORT = 80
DEFAULT_HTTPS_PORT = 443
DEFAULT_MAX_SEND_SIZE = 16 * 1024

DEFAULT_DATA_DIRECTORY_PATH = os.path.join(str(pathlib.Path.home()), '.proxy')
Expand Down
2 changes: 2 additions & 0 deletions proxy/core/acceptor/threadless.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def run(self) -> None:
self.loop = asyncio.get_event_loop()
while not self.running.is_set():
self.run_once()
except KeyboardInterrupt:
pass
finally:
assert self.selector is not None
self.selector.unregister(self.client_queue)
Expand Down
7 changes: 5 additions & 2 deletions proxy/dashboard/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def handle_request(self, request: HttpParser) -> None:
if request.path == b'/dashboard/':
self.client.queue(
HttpWebServerPlugin.read_and_build_static_file_response(
os.path.join(self.flags.static_server_dir, 'dashboard', 'proxy.html'),
os.path.join(
self.flags.static_server_dir,
'dashboard', 'proxy.html',
),
),
)
elif request.path in (
Expand Down Expand Up @@ -105,7 +108,7 @@ def on_websocket_message(self, frame: WebsocketFrame) -> None:
logger.info(frame.opcode)
self.reply({'id': message['id'], 'response': 'not_implemented'})

def on_websocket_close(self) -> None:
def on_client_connection_close(self) -> None:
logger.info('app ws closed')
# TODO(abhinavsingh): unsubscribe

Expand Down
6 changes: 4 additions & 2 deletions proxy/http/inspector/devtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def handle_request(self, request: HttpParser) -> None:

def on_websocket_open(self) -> None:
self.subscriber.subscribe(
lambda event: CoreEventsToDevtoolsProtocol.transformer(self.client, event),
lambda event: CoreEventsToDevtoolsProtocol.transformer(
self.client, event,
),
)

def on_websocket_message(self, frame: WebsocketFrame) -> None:
Expand All @@ -73,7 +75,7 @@ def on_websocket_message(self, frame: WebsocketFrame) -> None:
return
self.handle_devtools_message(message)

def on_websocket_close(self) -> None:
def on_client_connection_close(self) -> None:
self.subscriber.unsubscribe()

def handle_devtools_message(self, message: Dict[str, Any]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion proxy/http/server/pac_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def on_websocket_open(self) -> None:
def on_websocket_message(self, frame: WebsocketFrame) -> None:
pass # pragma: no cover

def on_websocket_close(self) -> None:
def on_client_connection_close(self) -> None:
pass # pragma: no cover

def cache_pac_file_response(self) -> None:
Expand Down
26 changes: 22 additions & 4 deletions proxy/http/server/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def __init__(
self.client = client
self.event_queue = event_queue

def name(self) -> str:
    """Return a unique name identifying this plugin.

    The class name is used by default, which lets plugin developers
    reach a specific plugin directly through its name.
    """
    return type(self).__name__  # pragma: no cover

# TODO(abhinavsingh): get_descriptors, write_to_descriptors, read_from_descriptors
# can be placed into their own abstract class which can then be shared by
# HttpProxyBasePlugin, HttpWebServerBasePlugin and HttpProtocolHandlerPlugin class.
Expand Down Expand Up @@ -79,6 +86,10 @@ def handle_request(self, request: HttpParser) -> None:
"""Handle the request and serve response."""
raise NotImplementedError() # pragma: no cover

def on_client_connection_close(self) -> None:
    """Hook for clean up work once the client has closed the connection.

    The default implementation does nothing.
    """
    return None

@abstractmethod
def on_websocket_open(self) -> None:
"""Called when websocket handshake has finished."""
Expand All @@ -89,7 +100,14 @@ def on_websocket_message(self, frame: WebsocketFrame) -> None:
"""Handle websocket frame."""
raise NotImplementedError() # pragma: no cover

@abstractmethod
def on_websocket_close(self) -> None:
"""Called when websocket connection has been closed."""
raise NotImplementedError() # pragma: no cover
# Deprecated since v2.4.0
#
# Instead use on_client_connection_close.
#
# This callback is no longer invoked. Kindly
# update your plugin before upgrading to v2.4.0.
#
# @abstractmethod
# def on_websocket_close(self) -> None:
# """Called when websocket connection has been closed."""
# raise NotImplementedError() # pragma: no cover
26 changes: 20 additions & 6 deletions proxy/http/server/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
}
self.route: Optional[HttpWebServerBasePlugin] = None

self.plugins: Dict[str, HttpWebServerBasePlugin] = {}
if b'HttpWebServerBasePlugin' in self.flags.plugins:
for klass in self.flags.plugins[b'HttpWebServerBasePlugin']:
instance: HttpWebServerBasePlugin = klass(
Expand All @@ -91,6 +92,7 @@ def __init__(
self.client,
self.event_queue,
)
self.plugins[instance.name()] = instance
for (protocol, route) in instance.routes():
self.routes[protocol][re.compile(route)] = instance

Expand Down Expand Up @@ -201,16 +203,28 @@ def on_request_complete(self) -> Union[socket.socket, bool]:
self.client.queue(self.DEFAULT_404_RESPONSE)
return True

# TODO(abhinavsingh): Call plugin get/read/write descriptor callbacks
def get_descriptors(
self,
) -> Tuple[List[socket.socket], List[socket.socket]]:
return [], []
r, w = [], []
for plugin in self.plugins.values():
r1, w1 = plugin.get_descriptors()
r.extend(r1)
w.extend(w1)
return r, w

def write_to_descriptors(self, w: Writables) -> bool:
    """Give each registered plugin a chance to write to its descriptors.

    Plugins are consulted in registration order; evaluation stops at the
    first plugin that asks for a teardown, in which case ``True`` is
    returned. ``False`` means no plugin requested a teardown.
    """
    return any(
        plugin.write_to_descriptors(w)
        for plugin in self.plugins.values()
    )

def read_from_descriptors(self, r: Readables) -> bool:
    """Give each registered plugin a chance to read from its descriptors.

    Returns ``True`` as soon as one plugin requests a teardown (later
    plugins are then not consulted), otherwise ``False``.
    """
    teardown_requests = (
        plugin.read_from_descriptors(r)
        for plugin in self.plugins.values()
    )
    # any() short-circuits on the first truthy teardown request.
    return any(teardown_requests)

def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
Expand Down Expand Up @@ -260,12 +274,12 @@ def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]:
def on_client_connection_close(self) -> None:
if self.request.has_host():
return
if self.switched_protocol:
# Invoke plugin.on_websocket_close
assert self.route
self.route.on_websocket_close()
if self.route:
self.route.on_client_connection_close()
self.access_log()

# TODO: Allow plugins to customize access_log, similar
# to how proxy server plugins are able to do it.
def access_log(self) -> None:
logger.info(
'%s:%s - %s %s - %.2f ms' %
Expand Down
110 changes: 98 additions & 12 deletions proxy/plugin/reverse_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,36 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import ssl
import random
from typing import List, Tuple
import socket
import logging
import sysconfig

from pathlib import Path
from typing import List, Optional, Tuple, Any
from urllib import parse as urlparse

from ..common.constants import DEFAULT_BUFFER_SIZE, DEFAULT_HTTP_PORT
from ..common.utils import socket_connection, text_
from ..common.utils import text_
from ..common.constants import DEFAULT_HTTPS_PORT, DEFAULT_HTTP_PORT
from ..common.types import Readables, Writables
from ..core.connection import TcpServerConnection
from ..http.exception import HttpProtocolException
from ..http.parser import HttpParser
from ..http.websocket import WebsocketFrame
from ..http.server import HttpWebServerBasePlugin, httpProtocolTypes

logger = logging.getLogger(__name__)

# We need CA bundle to verify TLS connection to upstream servers
PURE_LIB = sysconfig.get_path('purelib')
assert PURE_LIB
CACERT_PEM_PATH = Path(PURE_LIB) / 'certifi' / 'cacert.pem'


# TODO: ReverseProxyPlugin and ProxyPoolPlugin are implementing
# a similar behavior. Abstract that particular logic out into its
# own class.
class ReverseProxyPlugin(HttpWebServerBasePlugin):
"""Extend in-built Web Server to add Reverse Proxy capabilities.

Expand All @@ -39,35 +58,102 @@ class ReverseProxyPlugin(HttpWebServerBasePlugin):
"User-Agent": "curl/7.64.1"
},
"origin": "1.2.3.4, 5.6.7.8",
"url": "https://localhost/get"
"url": "http://localhost/get"
}
"""

# TODO: We must use nginx python parser and
# make this plugin nginx.conf compliant.
REVERSE_PROXY_LOCATION: str = r'/get$'
# Randomly choose either http or https upstream endpoint.
#
# This is just to demonstrate that reverse proxying works for both
# http and https upstream endpoints.
REVERSE_PROXY_PASS = [
b'http://httpbin.org/get',
b'https://httpbin.org/get',
]

def __init__(self, *args: Any, **kwargs: Any):
    """Initialise the base plugin and start without an upstream connection."""
    super().__init__(*args, **kwargs)
    # No upstream server connection exists until a request is handled.
    self.upstream: Optional[TcpServerConnection] = None

def routes(self) -> List[Tuple[int, str]]:
    """Return the reverse proxy location paired with HTTP and HTTPS."""
    location = ReverseProxyPlugin.REVERSE_PROXY_LOCATION
    return [
        (protocol, location)
        for protocol in (httpProtocolTypes.HTTP, httpProtocolTypes.HTTPS)
    ]

# TODO(abhinavsingh): Upgrade to use non-blocking get/read/write API.
def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]:
    """Return the ``(readables, writables)`` sockets for the upstream server.

    Both lists are empty while there is no upstream connection. Otherwise
    the upstream socket is always listed as readable, and is listed as
    writable only while the upstream connection has buffered data.
    """
    upstream = self.upstream
    if not upstream:
        return [], []
    conn = upstream.connection
    readables = [conn]
    writables = [conn] if upstream.has_buffer() else []
    return readables, writables

def read_from_descriptors(self, r: Readables) -> bool:
    """Relay data received from the upstream server to the client.

    Returns ``True`` when the connection must be torn down (upstream
    closed or reset the connection) and ``False`` when a TLS read has to
    be retried later. In every other case the decision is delegated to
    the parent class.
    """
    upstream = self.upstream
    if not upstream or upstream.connection not in r:
        # Nothing to read from upstream in this iteration.
        return super().read_from_descriptors(r)
    try:
        raw = upstream.recv(self.flags.server_recvbuf_size)
        if raw is None:
            # Teardown because upstream server closed the connection
            return True
        self.client.queue(raw)
    except ssl.SSLWantReadError:
        logger.info('Upstream server SSLWantReadError, will retry')
        return False
    except ConnectionResetError:
        logger.debug('Connection reset by upstream server')
        return True
    return super().read_from_descriptors(r)

def write_to_descriptors(self, w: Writables) -> bool:
    """Flush buffered data to the upstream server when it is writable.

    Returns ``True`` when the connection must be torn down (broken pipe)
    and ``False`` when a TLS write has to be retried later. In every
    other case the decision is delegated to the parent class.
    """
    upstream = self.upstream
    ready = upstream and upstream.connection in w and upstream.has_buffer()
    if not ready:
        # Either no upstream, not writable right now, or nothing buffered.
        return super().write_to_descriptors(w)
    try:
        upstream.flush()
    except ssl.SSLWantWriteError:
        logger.info('Upstream server SSLWantWriteError, will retry')
        return False
    except BrokenPipeError:
        logger.debug(
            'BrokenPipeError when flushing to upstream server',
        )
        return True
    return super().write_to_descriptors(w)

def handle_request(self, request: HttpParser) -> None:
upstream = random.choice(ReverseProxyPlugin.REVERSE_PROXY_PASS)
url = urlparse.urlsplit(upstream)
url = urlparse.urlsplit(
random.choice(ReverseProxyPlugin.REVERSE_PROXY_PASS),
)
assert url.hostname
with socket_connection((text_(url.hostname), url.port if url.port else DEFAULT_HTTP_PORT)) as conn:
conn.send(request.build())
self.client.queue(memoryview(conn.recv(DEFAULT_BUFFER_SIZE)))
port = url.port or (
DEFAULT_HTTP_PORT if url.scheme ==
b'http' else DEFAULT_HTTPS_PORT
)
self.upstream = TcpServerConnection(text_(url.hostname), port)
try:
self.upstream.connect()
if url.scheme == b'https':
self.upstream.wrap(
text_(
url.hostname,
), ca_file=str(CACERT_PEM_PATH),
)
self.upstream.queue(memoryview(request.build()))
except ConnectionRefusedError:
logger.info(
'Connection refused by upstream server {0}:{1}'.format(
text_(url.hostname), port,
),
)
raise HttpProtocolException()

def on_websocket_open(self) -> None:
pass

def on_websocket_message(self, frame: WebsocketFrame) -> None:
pass

def on_websocket_close(self) -> None:
pass
def on_client_connection_close(self) -> None:
    """Close the upstream server connection when the client goes away.

    Does nothing when there is no upstream connection or it is already
    closed.
    """
    upstream = self.upstream
    if not upstream or upstream.closed:
        return
    logger.debug('Closing upstream server connection')
    upstream.close()
    # NOTE(review): the reference is dropped only after an actual close;
    # confirm this matches the intended nesting of the original statement.
    self.upstream = None
2 changes: 1 addition & 1 deletion proxy/plugin/web_server_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,5 @@ def on_websocket_open(self) -> None:
def on_websocket_message(self, frame: WebsocketFrame) -> None:
logger.info(frame.data)

def on_websocket_close(self) -> None:
def on_client_connection_close(self) -> None:
logger.info('Websocket close')