Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Top-level notion of work not client #695

Merged
merged 2 commits into from
Nov 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/https_connect_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:

# Drop the request if not a CONNECT request
if self.request.method != httpMethods.CONNECT:
self.client.queue(
self.work.queue(
HttpsConnectTunnelHandler.PROXY_TUNNEL_UNSUPPORTED_SCHEME,
)
return True
Expand All @@ -66,7 +66,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
self.connect_upstream()

# Queue tunnel established response to client
self.client.queue(
self.work.queue(
HttpsConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
)

Expand Down
8 changes: 4 additions & 4 deletions examples/ssl_echo_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ def initialize(self) -> None:
# here using wrap_socket() utility.
assert self.flags.keyfile is not None and self.flags.certfile is not None
conn = wrap_socket(
self.client.connection,
self.work.connection,
self.flags.keyfile,
self.flags.certfile,
)
conn.setblocking(False)
# Upgrade plain TcpClientConnection to SSL connection object
self.client = TcpClientConnection(
conn=conn, addr=self.client.addr,
self.work = TcpClientConnection(
conn=conn, addr=self.work.addr,
)

def handle_data(self, data: memoryview) -> Optional[bool]:
# echo back to client
self.client.queue(data)
self.work.queue(data)
return None


Expand Down
4 changes: 2 additions & 2 deletions examples/tcp_echo_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class EchoServerHandler(BaseTcpServerHandler):
"""Sets client socket to non-blocking during initialization."""

def initialize(self) -> None:
self.client.connection.setblocking(False)
self.work.connection.setblocking(False)

def handle_data(self, data: memoryview) -> Optional[bool]:
# echo back to client
self.client.queue(data)
self.work.queue(data)
return None


Expand Down
9 changes: 6 additions & 3 deletions proxy/core/acceptor/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ class Work(ABC):

def __init__(
self,
client: TcpClientConnection,
work: TcpClientConnection,
flags: argparse.Namespace,
event_queue: Optional[EventQueue] = None,
uid: Optional[UUID] = None,
) -> None:
self.client = client
# Work uuid
self.uid: UUID = uid if uid is not None else uuid4()
self.flags = flags
# Eventing core queue
self.event_queue = event_queue
self.uid: UUID = uid if uid is not None else uuid4()
# Accept work
self.work = work

@abstractmethod
def get_events(self) -> Dict[socket.socket, int]:
Expand Down
34 changes: 17 additions & 17 deletions proxy/core/base/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class BaseTcpServerHandler(Work):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.must_flush_before_shutdown = False
logger.debug('Connection accepted from {0}'.format(self.client.addr))
logger.debug('Connection accepted from {0}'.format(self.work.addr))

@abstractmethod
def handle_data(self, data: memoryview) -> Optional[bool]:
Expand All @@ -57,14 +57,14 @@ def get_events(self) -> Dict[socket.socket, int]:
# We always want to read from client
# Register for EVENT_READ events
if self.must_flush_before_shutdown is False:
events[self.client.connection] = selectors.EVENT_READ
events[self.work.connection] = selectors.EVENT_READ
# If there is pending buffer for client
# also register for EVENT_WRITE events
if self.client.has_buffer():
if self.client.connection in events:
events[self.client.connection] |= selectors.EVENT_WRITE
if self.work.has_buffer():
if self.work.connection in events:
events[self.work.connection] |= selectors.EVENT_WRITE
else:
events[self.client.connection] = selectors.EVENT_WRITE
events[self.work.connection] = selectors.EVENT_WRITE
return events

def handle_events(
Expand All @@ -79,32 +79,32 @@ def handle_events(
if teardown:
logger.debug(
'Shutting down client {0} connection'.format(
self.client.addr,
self.work.addr,
),
)
return teardown

def handle_writables(self, writables: Writables) -> bool:
teardown = False
if self.client.connection in writables and self.client.has_buffer():
if self.work.connection in writables and self.work.has_buffer():
logger.debug(
'Flushing buffer to client {0}'.format(self.client.addr),
'Flushing buffer to client {0}'.format(self.work.addr),
)
self.client.flush()
self.work.flush()
if self.must_flush_before_shutdown is True:
if not self.client.has_buffer():
if not self.work.has_buffer():
teardown = True
self.must_flush_before_shutdown = False
return teardown

def handle_readables(self, readables: Readables) -> bool:
teardown = False
if self.client.connection in readables:
data = self.client.recv(self.flags.client_recvbuf_size)
if self.work.connection in readables:
data = self.work.recv(self.flags.client_recvbuf_size)
if data is None:
logger.debug(
'Connection closed by client {0}'.format(
self.client.addr,
self.work.addr,
),
)
teardown = True
Expand All @@ -113,13 +113,13 @@ def handle_readables(self, readables: Readables) -> bool:
if isinstance(r, bool) and r is True:
logger.debug(
'Implementation signaled shutdown for client {0}'.format(
self.client.addr,
self.work.addr,
),
)
if self.client.has_buffer():
if self.work.has_buffer():
logger.debug(
'Client {0} has pending buffer, will be flushed before shutting down'.format(
self.client.addr,
self.work.addr,
),
)
self.must_flush_before_shutdown = True
Expand Down
4 changes: 2 additions & 2 deletions proxy/core/base/tcp_tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
pass # pragma: no cover

def initialize(self) -> None:
self.client.connection.setblocking(False)
self.work.connection.setblocking(False)

def shutdown(self) -> None:
if self.upstream:
Expand Down Expand Up @@ -87,7 +87,7 @@ def handle_events(
print('Connection closed by server')
return True
# tunnel data to client
self.client.queue(data)
self.work.queue(data)
if self.upstream and self.upstream.connection in writables:
self.upstream.flush()
return False
Expand Down
46 changes: 23 additions & 23 deletions proxy/http/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,25 @@ def __init__(self, *args: Any, **kwargs: Any):

def initialize(self) -> None:
"""Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins."""
conn = self._optionally_wrap_socket(self.client.connection)
conn = self._optionally_wrap_socket(self.work.connection)
conn.setblocking(False)
# Update client connection reference if connection was wrapped
if self._encryption_enabled():
self.client = TcpClientConnection(conn=conn, addr=self.client.addr)
self.work = TcpClientConnection(conn=conn, addr=self.work.addr)
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
instance: HttpProtocolHandlerPlugin = klass(
self.uid,
self.flags,
self.client,
self.work,
self.request,
self.event_queue,
)
self.plugins[instance.name()] = instance
logger.debug('Handling connection %r' % self.client.connection)
logger.debug('Handling connection %r' % self.work.connection)

def is_inactive(self) -> bool:
if not self.client.has_buffer() and \
if not self.work.has_buffer() and \
self._connection_inactive_for() > self.flags.timeout:
return True
return False
Expand All @@ -127,20 +127,20 @@ def shutdown(self) -> None:
logger.debug(
'Closing client connection %r '
'at address %r has buffer %s' %
(self.client.connection, self.client.addr, self.client.has_buffer()),
(self.work.connection, self.work.addr, self.work.has_buffer()),
)

conn = self.client.connection
conn = self.work.connection
# Unwrap if wrapped before shutdown.
if self._encryption_enabled() and \
isinstance(self.client.connection, ssl.SSLSocket):
conn = self.client.connection.unwrap()
isinstance(self.work.connection, ssl.SSLSocket):
conn = self.work.connection.unwrap()
conn.shutdown(socket.SHUT_WR)
logger.debug('Client connection shutdown successful')
except OSError:
pass
finally:
self.client.connection.close()
self.work.connection.close()
logger.debug('Client connection closed')
super().shutdown()

Expand Down Expand Up @@ -196,7 +196,7 @@ def handle_events(
def handle_data(self, data: memoryview) -> Optional[bool]:
if data is None:
logger.debug('Client closed connection, tearing down...')
self.client.closed = True
self.work.closed = True
return True

try:
Expand Down Expand Up @@ -227,7 +227,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
logger.debug(
'Updated client conn to %s', upgraded_sock,
)
self.client._conn = upgraded_sock
self.work._conn = upgraded_sock
for plugin_ in self.plugins.values():
if plugin_ != plugin:
plugin_.client._conn = upgraded_sock
Expand All @@ -237,20 +237,20 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
logger.debug('HttpProtocolException raised')
response: Optional[memoryview] = e.response(self.request)
if response:
self.client.queue(response)
self.work.queue(response)
return True
return False

def handle_writables(self, writables: Writables) -> bool:
if self.client.connection in writables and self.client.has_buffer():
if self.work.connection in writables and self.work.has_buffer():
logger.debug('Client is ready for writes, flushing buffer')
self.last_activity = time.time()

# TODO(abhinavsingh): This hook could just reside within server recv block
# instead of invoking when flushed to client.
#
# Invoke plugin.on_response_chunk
chunk = self.client.buffer
chunk = self.work.buffer
for plugin in self.plugins.values():
chunk = plugin.on_response_chunk(chunk)
if chunk is None:
Expand All @@ -272,7 +272,7 @@ def handle_writables(self, writables: Writables) -> bool:
return False

def handle_readables(self, readables: Readables) -> bool:
if self.client.connection in readables:
if self.work.connection in readables:
logger.debug('Client is ready for reads, reading')
self.last_activity = time.time()
try:
Expand All @@ -290,7 +290,7 @@ def handle_readables(self, readables: Readables) -> bool:
else:
logger.exception(
'Exception while receiving from %s connection %r with reason %r' %
(self.client.tag, self.client.connection, e),
(self.work.tag, self.work.connection, e),
)
return True
return False
Expand Down Expand Up @@ -324,7 +324,7 @@ def run(self) -> None:
except Exception as e:
logger.exception(
'Exception while handling connection %r' %
self.client.connection, exc_info=e,
self.work.connection, exc_info=e,
)
finally:
self.shutdown()
Expand Down Expand Up @@ -377,24 +377,24 @@ def _run_once(self) -> bool:

def _flush(self) -> None:
assert self.selector
if not self.client.has_buffer():
if not self.work.has_buffer():
return
try:
self.selector.register(
self.client.connection,
self.work.connection,
selectors.EVENT_WRITE,
)
while self.client.has_buffer():
while self.work.has_buffer():
ev: List[
Tuple[selectors.SelectorKey, int]
] = self.selector.select(timeout=1)
if len(ev) == 0:
continue
self.client.flush()
self.work.flush()
except BrokenPipeError:
pass
finally:
self.selector.unregister(self.client.connection)
self.selector.unregister(self.work.connection)

def _connection_inactive_for(self) -> float:
return time.time() - self.last_activity
12 changes: 6 additions & 6 deletions tests/http/exceptions/test_http_proxy_auth_failed.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def test_proxy_auth_fails_without_cred(self, mock_server_conn: mock.Mock) -> Non

self.protocol_handler._run_once()
mock_server_conn.assert_not_called()
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
self.assertEqual(
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
)
self._conn.send.assert_not_called()

Expand All @@ -92,9 +92,9 @@ def test_proxy_auth_fails_with_invalid_cred(self, mock_server_conn: mock.Mock) -

self.protocol_handler._run_once()
mock_server_conn.assert_not_called()
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
self.assertEqual(
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
)
self._conn.send.assert_not_called()

Expand All @@ -121,7 +121,7 @@ def test_proxy_auth_works_with_valid_cred(self, mock_server_conn: mock.Mock) ->

self.protocol_handler._run_once()
mock_server_conn.assert_called_once()
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
self.assertEqual(self.protocol_handler.work.has_buffer(), False)

@mock.patch('proxy.http.proxy.server.TcpServerConnection')
def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: mock.Mock) -> None:
Expand All @@ -146,4 +146,4 @@ def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: m

self.protocol_handler._run_once()
mock_server_conn.assert_called_once()
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
self.assertEqual(self.protocol_handler.work.has_buffer(), False)
2 changes: 1 addition & 1 deletion tests/http/test_http_proxy_tls_interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def mock_connection() -> Any:
)
self.assertEqual(self._conn.setblocking.call_count, 2)
self.assertEqual(
self.protocol_handler.client.connection,
self.protocol_handler.work.connection,
self.mock_ssl_wrap.return_value,
)

Expand Down
Loading