Skip to content

Commit

Permalink
Decouple SSL wrap logic into connection classes (#394)
Browse files Browse the repository at this point in the history
* Move wrap functionality within respective connection classes. Also decouple websocket client handshake method

* Add a TCP echo client example that works with TCP echo server example
  • Loading branch information
abhinavsingh authored Jul 8, 2020
1 parent c884338 commit 682114e
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 34 deletions.
21 changes: 21 additions & 0 deletions examples/tcp_echo_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
from proxy.common.utils import socket_connection
from proxy.common.constants import DEFAULT_BUFFER_SIZE

if __name__ == '__main__':
with socket_connection(('::', 12345)) as client:
while True:
client.send(b'hello')
data = client.recv(DEFAULT_BUFFER_SIZE)
if data is None:
break
print(data)
27 changes: 20 additions & 7 deletions examples/tcp_echo_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import socket
import selectors

from typing import Dict
from typing import Dict, Any

from proxy.core.acceptor import AcceptorPool, Work
from proxy.common.flags import Flags
Expand All @@ -29,6 +29,10 @@ class EchoServerHandler(Work):
intialize, is_inactive and shutdown method.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
print('Connection accepted from {0}'.format(self.client.addr))

def get_events(self) -> Dict[socket.socket, int]:
# We always want to read from client
# Register for EVENT_READ events
Expand All @@ -45,12 +49,21 @@ def handle_events(
writables: Writables) -> bool:
"""Return True to shutdown work."""
if self.client.connection in readables:
data = self.client.recv()
if data is None:
# Client closed connection, signal shutdown
try:
data = self.client.recv()
if data is None:
# Client closed connection, signal shutdown
print(
'Connection closed by client {0}'.format(
self.client.addr))
return True
# Echo data back to client
self.client.queue(data)
except ConnectionResetError:
print(
'Connection reset by client {0}'.format(
self.client.addr))
return True
# Queue data back to client
self.client.queue(data)

if self.client.connection in writables:
self.client.flush()
Expand All @@ -61,7 +74,7 @@ def handle_events(
def main() -> None:
# This example requires `threadless=True`
pool = AcceptorPool(
flags=Flags(num_workers=1, threadless=True),
flags=Flags(port=12345, num_workers=1, threadless=True),
work_klass=EchoServerHandler)
try:
pool.setup()
Expand Down
12 changes: 9 additions & 3 deletions examples/websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
def on_message(frame: WebsocketFrame) -> None:
"""WebsocketClient on_message callback."""
global client, num_echos, last_dispatch_time
print('Received %r after %d millisec' % (frame.data, (time.time() - last_dispatch_time) * 1000))
assert(frame.data == b'hello' and frame.opcode == websocketOpcodes.TEXT_FRAME)
print('Received %r after %d millisec' %
(frame.data, (time.time() - last_dispatch_time) * 1000))
assert(frame.data == b'hello' and frame.opcode ==
websocketOpcodes.TEXT_FRAME)
if num_echos > 0:
client.queue(static_frame)
last_dispatch_time = time.time()
Expand All @@ -34,7 +36,11 @@ def on_message(frame: WebsocketFrame) -> None:

if __name__ == '__main__':
# Constructor establishes socket connection
client = WebsocketClient(b'echo.websocket.org', 80, b'/', on_message=on_message)
client = WebsocketClient(
b'echo.websocket.org',
80,
b'/',
on_message=on_message)
# Perform handshake
client.handshake()
# Queue some data for client
Expand Down
12 changes: 12 additions & 0 deletions proxy/core/connection/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,15 @@ def connection(self) -> Union[ssl.SSLSocket, socket.socket]:
if self._conn is None:
raise TcpConnectionUninitializedException()
return self._conn

def wrap(self, keyfile: str, certfile: str) -> None:
self.connection.setblocking(True)
self.flush()
self._conn = ssl.wrap_socket(
self.connection,
server_side=True,
# ca_certs=self.flags.ca_cert_file,
certfile=certfile,
keyfile=keyfile,
ssl_version=ssl.PROTOCOL_TLS)
self.connection.setblocking(False)
13 changes: 12 additions & 1 deletion proxy/core/connection/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import socket
import ssl
import socket
from typing import Optional, Union, Tuple

from .connection import TcpConnection, tcpConnectionTypes, TcpConnectionUninitializedException
Expand All @@ -34,3 +34,14 @@ def connect(self) -> None:
if self._conn is not None:
return
self._conn = new_socket_connection(self.addr)

def wrap(self, hostname: str, ca_file: Optional[str]) -> None:
ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=ca_file)
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1
ctx.check_hostname = True
self.connection.setblocking(True)
self._conn = ctx.wrap_socket(
self.connection,
server_hostname=hostname)
self.connection.setblocking(False)
24 changes: 4 additions & 20 deletions proxy/http/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,31 +453,15 @@ def generate_upstream_certificate(
def wrap_server(self) -> None:
assert self.server is not None
assert isinstance(self.server.connection, socket.socket)
ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=self.flags.ca_file)
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1
ctx.check_hostname = True
self.server.connection.setblocking(True)
self.server._conn = ctx.wrap_socket(
self.server.connection,
server_hostname=text_(self.request.host))
self.server.connection.setblocking(False)
self.server.wrap(text_(self.request.host), self.flags.ca_file)
assert isinstance(self.server.connection, ssl.SSLSocket)

def wrap_client(self) -> None:
assert self.server is not None
assert self.server is not None and self.flags.ca_signing_key_file is not None
assert isinstance(self.server.connection, ssl.SSLSocket)
generated_cert = self.generate_upstream_certificate(
cast(Dict[str, Any], self.server.connection.getpeercert()))
self.client.connection.setblocking(True)
self.client.flush()
self.client._conn = ssl.wrap_socket(
self.client.connection,
server_side=True,
# ca_certs=self.flags.ca_cert_file,
certfile=generated_cert,
keyfile=self.flags.ca_signing_key_file,
ssl_version=ssl.PROTOCOL_TLS)
self.client.connection.setblocking(False)
self.client.wrap(self.flags.ca_signing_key_file, generated_cert)
logger.debug(
'TLS interception using %s', generated_cert)

Expand Down
6 changes: 5 additions & 1 deletion proxy/http/websocket/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def handshake(self) -> None:

def upgrade(self) -> None:
key = base64.b64encode(secrets.token_bytes(16))
self.sock.send(build_websocket_handshake_request(key, url=self.path, host=self.hostname))
self.sock.send(
build_websocket_handshake_request(
key,
url=self.path,
host=self.hostname))
response = HttpParser(httpParserTypes.RESPONSE_PARSER)
response.parse(self.sock.recv(DEFAULT_BUFFER_SIZE))
accept = response.header(b'Sec-Websocket-Accept')
Expand Down
7 changes: 6 additions & 1 deletion tests/http/test_http_proxy_tls_interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any
from unittest import mock

from proxy.core.connection import TcpClientConnection
from proxy.core.connection import TcpClientConnection, TcpServerConnection
from proxy.http.handler import HttpProtocolHandler
from proxy.http.proxy import HttpProxyPlugin
from proxy.http.methods import httpMethods
Expand Down Expand Up @@ -71,6 +71,11 @@ def mock_connection() -> Any:
return ssl_connection
return plain_connection

# Do not mock the original wrap method
self.mock_server_conn.return_value.wrap.side_effect = \
lambda x, y: TcpServerConnection.wrap(
self.mock_server_conn.return_value, x, y)

type(self.mock_server_conn.return_value).connection = \
mock.PropertyMock(side_effect=mock_connection)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from proxy.common.utils import bytes_
from proxy.common.flags import Flags
from proxy.common.utils import build_http_request, build_http_response
from proxy.core.connection import TcpClientConnection
from proxy.core.connection import TcpClientConnection, TcpServerConnection
from proxy.http.codes import httpStatusCodes
from proxy.http.methods import httpMethods
from proxy.http.handler import HttpProtocolHandler
Expand Down Expand Up @@ -98,6 +98,10 @@ def mock_connection() -> Any:
return self.server_ssl_connection
return self._conn

# Do not mock the original wrap method
self.server.wrap.side_effect = \
lambda x, y: TcpServerConnection.wrap(self.server, x, y)

self.server.has_buffer.side_effect = has_buffer
type(self.server).closed = mock.PropertyMock(side_effect=closed)
type(
Expand Down

0 comments on commit 682114e

Please sign in to comment.