From da23c7f70c76e9cfda41fea001ebcc91dae4b591 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Sun, 7 Nov 2021 05:02:52 +0530 Subject: [PATCH] Work (#693) * Refactor work acceptor and executor * Lint fixes * Fix expression-not-assigned pylint error --- proxy/core/acceptor/acceptor.py | 100 ++++++++++++++++++------ proxy/core/acceptor/threadless.py | 40 ++++++------ proxy/core/base/tcp_server.py | 8 --- proxy/http/handler.py | 4 +- proxy/http/proxy/server.py | 54 +++++++++------- 5 files changed, 115 insertions(+), 91 deletions(-) diff --git a/proxy/core/acceptor/acceptor.py b/proxy/core/acceptor/acceptor.py index 78e0f02bd9..985b9860cc 100644 --- a/proxy/core/acceptor/acceptor.py +++ b/proxy/core/acceptor/acceptor.py @@ -42,19 +42,25 @@ class Acceptor(multiprocessing.Process): - """Socket server acceptor process. + """Work acceptor process. - Accepts a server socket fd over `work_queue` and start listening for client - connections over the passed server socket. By default, it spawns a separate thread - to handle each client request. + On start-up, `Acceptor` accepts a file descriptor which will be used to + accept new work. The file descriptor is received over a `work_queue`, which is + closed immediately after receiving the descriptor. - However, if `--threadless` option is enabled, Acceptor process will also pre-spawns a `Threadless` - process at startup. Accepted client connections are then passed to the `Threadless` process - which internally uses asyncio event loop to handle client connections. + `Acceptor` goes on to listen for new work over the received server socket. + By default, `Acceptor` will spawn a new thread to handle each piece of work. - TODO(abhinavsingh): Instead of starting `Threadless` process, can we work with a `Threadless` thread? - What are the performance implications of sharing fds between threads vs processes? How much performance - degradation happen when processes are running on separate CPU cores? + However, when the `--threadless` option is enabled, the `Acceptor` process will also pre-spawn a + `Threadless` process during start-up. Accepted work is then passed to this `Threadless` process. + The `Acceptor` process shares accepted work with a `Threadless` process over its dedicated pipe. + + TODO(abhinavsingh): Open questions: + 1) Instead of starting a `Threadless` process, can we work with a `Threadless` thread? + 2) What are the performance implications of sharing fds between threads vs processes? + 3) How much performance degradation happens when acceptor and threadless processes are + running on separate CPU cores? + 4) Can we ensure both acceptor and threadless processes are pinned to the same CPU core?
""" def __init__( @@ -67,18 +73,26 @@ def __init__( event_queue: Optional[EventQueue] = None, ) -> None: super().__init__() + self.flags = flags + # Lock shared by all acceptor processes + # to avoid concurrent accept over server socket + self.lock = lock + # Index assigned by `AcceptorPool` self.idd = idd + # Queue over which server socket fd is received on start-up self.work_queue: connection.Connection = work_queue - self.flags = flags + # Worker class self.work_klass = work_klass - self.lock = lock + # Eventing core queue self.event_queue = event_queue - + # Selector & threadless states self.running = multiprocessing.Event() self.selector: Optional[selectors.DefaultSelector] = None - self.sock: Optional[socket.socket] = None self.threadless_process: Optional[Threadless] = None self.threadless_client_queue: Optional[connection.Connection] = None + # File descriptor used to accept new work + # Currently, a socket fd is assumed. + self.sock: Optional[socket.socket] = None def start_threadless_process(self) -> None: pipe = multiprocessing.Pipe() @@ -99,31 +113,30 @@ def shutdown_threadless_process(self) -> None: self.threadless_process.join() self.threadless_client_queue.close() - def start_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None: - if self.flags.threadless and \ - self.threadless_client_queue and \ - self.threadless_process: - self.threadless_client_queue.send(addr) - send_handle( - self.threadless_client_queue, - conn.fileno(), - self.threadless_process.pid, - ) - conn.close() - else: - work = self.work_klass( - TcpClientConnection(conn, addr), - flags=self.flags, - event_queue=self.event_queue, - ) - work_thread = threading.Thread(target=work.run) - work_thread.daemon = True - work.publish_event( - event_name=eventNames.WORK_STARTED, - event_payload={'fileno': conn.fileno(), 'addr': addr}, - publisher_id=self.__class__.__name__, - ) - work_thread.start() + def _start_threadless_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None: + assert self.threadless_process and self.threadless_client_queue + self.threadless_client_queue.send(addr) + send_handle( + self.threadless_client_queue, + conn.fileno(), + self.threadless_process.pid, + ) + conn.close() + + def _start_threaded_work(self, conn: socket.socket, addr: Tuple[str, int]) -> None: + work = self.work_klass( + TcpClientConnection(conn, addr), + flags=self.flags, + event_queue=self.event_queue, + ) + work_thread = threading.Thread(target=work.run) + work_thread.daemon = True + work.publish_event( + event_name=eventNames.WORK_STARTED, + event_payload={'fileno': conn.fileno(), 'addr': addr}, + publisher_id=self.__class__.__name__, + ) + work_thread.start() def run_once(self) -> None: with self.lock: @@ -132,7 +145,14 @@ def run_once(self) -> None: if len(events) == 0: return conn, addr = self.sock.accept() - self.start_work(conn, addr) + if ( + self.flags.threadless and + self.threadless_client_queue and + self.threadless_process + ): + self._start_threadless_work(conn, addr) + else: + self._start_threaded_work(conn, addr) def run(self) -> None: setup_logger( diff --git a/proxy/core/acceptor/threadless.py b/proxy/core/acceptor/threadless.py index d50d311d38..eaac4fb83e 100644 --- a/proxy/core/acceptor/threadless.py +++ b/proxy/core/acceptor/threadless.py @@ -34,21 +34,22 @@ class Threadless(multiprocessing.Process): - """Threadless process provides an event loop. + """Work executor process. - Internally, for each client connection, an instance of `work_klass` - is created. 
Threadless will invoke necessary lifecycle of the `Work` class - allowing implementations to handle accepted client connections as they wish. + Threadless process provides an event loop, which is shared across + multiple `Work` instances to handle work. - Note that, all `Work` implementations share the same underlying event loop. + Threadless takes a `work_klass` and an `event_queue` as input. `work_klass` + must conform to the `Work` protocol. Work is received over the + `event_queue`. - When --threadless option is enabled, each Acceptor process also - spawns one Threadless process. And instead of spawning new thread - for each accepted client connection, Acceptor process sends - accepted client connection to Threadless process over a pipe. + When new work is accepted, Threadless creates a new instance of `work_klass`. + Threadless will then invoke the necessary lifecycle methods of the `Work` protocol, + allowing the `work_klass` implementation to handle the assigned work. - Example, HttpProtocolHandler implements Work class to hooks into the - event loop provided by Threadless process. + For example, `BaseTcpServerHandler` implements the `Work` protocol. It expects + a client connection as work payload and hooks into the threadless + event loop to handle the client connection. """ def __init__( @@ -82,13 +83,10 @@ def selected_events(self) -> Generator[ for fd in worker_events: # Can throw ValueError: Invalid file descriptor: -1 # - # Work classes must handle the exception and shutdown - # gracefully otherwise this will result in bringing down the - # entire threadless process - # - # This is only possible when work.get_events pass - # an invalid file descriptor. Example, because of bad - # exception handling within the work implementation class. + # A guard within Work classes may not help here due to the + # asynchronous nature of execution. Hence, threadless will handle + # ValueError exceptions raised by selector.register + # for an invalid fd. self.selector.register(fd, worker_events[fd]) ev = self.selector.select(timeout=1) readables = [] @@ -180,6 +178,10 @@ def run_once(self) -> None: # Note that selector from now on is idle, # until all the logic below completes. # + # This is where the one-process-per-CPU architecture shines, + # as other threadless processes can continue processing work + # within their context.
+ # # Invoke Threadless.handle_events # # TODO: Only send readable / writables that client originally @@ -194,7 +196,7 @@ def run_once(self) -> None: self.accept_client() # Wait for Threadless.handle_events to complete self.loop.run_until_complete(self.wait_for_tasks(tasks)) - # Remove and shutdown inactive connections + # Remove and shutdown inactive workers self.cleanup_inactive() def run(self) -> None: diff --git a/proxy/core/base/tcp_server.py b/proxy/core/base/tcp_server.py index 73e3fdeb31..9931ae34e7 100644 --- a/proxy/core/base/tcp_server.py +++ b/proxy/core/base/tcp_server.py @@ -102,7 +102,6 @@ def handle_readables(self, readables: Readables) -> bool: if self.client.connection in readables: data = self.client.recv(self.flags.client_recvbuf_size) if data is None: - # Client closed connection, signal shutdown logger.debug( 'Connection closed by client {0}'.format( self.client.addr, @@ -126,11 +125,4 @@ def handle_readables(self, readables: Readables) -> bool: self.must_flush_before_shutdown = True else: teardown = True - # except ConnectionResetError: - # logger.debug( - # 'Connection reset by client {0}'.format( - # self.client.addr, - # ), - # ) - # teardown = True return teardown diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 1c2dd83189..5fe545279c 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -234,9 +234,7 @@ def handle_data(self, data: memoryview) -> Optional[bool]: elif isinstance(upgraded_sock, bool) and upgraded_sock is True: return True except HttpProtocolException as e: - logger.debug( - 'HttpProtocolException type raised', - ) + logger.debug('HttpProtocolException raised') response: Optional[memoryview] = e.response(self.request) if response: self.client.queue(response) diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 113fe6599d..43d1b67377 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -151,10 +151,18 @@ def get_descriptors( r: List[socket.socket] = [] w: List[socket.socket] = [] - if self.upstream and not self.upstream.closed and self.upstream.connection: + if ( + self.upstream and + not self.upstream.closed and + self.upstream.connection + ): r.append(self.upstream.connection) - if self.upstream and not self.upstream.closed and \ - self.upstream.has_buffer() and self.upstream.connection: + if ( + self.upstream and + not self.upstream.closed and + self.upstream.has_buffer() and + self.upstream.connection + ): w.append(self.upstream.connection) # TODO(abhinavsingh): We need to keep a mapping of plugin and @@ -658,19 +666,11 @@ def generate_upstream_certificate( def intercept(self) -> Union[socket.socket, bool]: # Perform SSL/TLS handshake with upstream - teardown = self.wrap_server() - if teardown: - raise HttpProtocolException( - 'Exception when wrapping server for interception', - ) + self.wrap_server() # Generate certificate and perform handshake with client # wrap_client also flushes client data before wrapping # sending to client can raise, handle expected exceptions - teardown = self.wrap_client() - if teardown: - raise HttpProtocolException( - 'Exception when wrapping client for interception', - ) + self.wrap_client() # Update all plugin connection reference # TODO(abhinavsingh): Is this required? 
for plugin in self.plugins.values(): @@ -680,6 +680,7 @@ def intercept(self) -> Union[socket.socket, bool]: def wrap_server(self) -> bool: assert self.upstream is not None assert isinstance(self.upstream.connection, socket.socket) + do_close = False try: self.upstream.wrap(text_(self.request.host), self.flags.ca_file) except ssl.SSLCertVerificationError: # Server raised certificate verification error @@ -692,7 +693,7 @@ def wrap_server(self) -> bool: self.upstream.addr[0], ), ) - return True + do_close = True except ssl.SSLError as e: if e.reason == 'SSLV3_ALERT_HANDSHAKE_FAILURE': logger.warning( @@ -707,13 +708,19 @@ def wrap_server(self) -> bool: self.upstream.addr[0], ), exc_info=e, ) - return True + do_close = True + finally: + if do_close: + raise HttpProtocolException( + 'Exception when wrapping server for interception', + ) assert isinstance(self.upstream.connection, ssl.SSLSocket) return False def wrap_client(self) -> bool: assert self.upstream is not None and self.flags.ca_signing_key_file is not None assert isinstance(self.upstream.connection, ssl.SSLSocket) + do_close = False try: # TODO: Perform async certificate generation generated_cert = self.generate_upstream_certificate( @@ -724,7 +731,7 @@ def wrap_client(self) -> bool: logger.exception( 'TimeoutExpired during certificate generation', exc_info=e, ) - return True + do_close = True except ssl.SSLCertVerificationError: # Client raised certificate verification error # When --disable-interception-on-ssl-cert-verification-error flag is on, # we will cache such upstream hosts and avoid intercepting them for future @@ -735,14 +742,14 @@ def wrap_client(self) -> bool: self.upstream.addr[0], ), ) - return True + do_close = True except ssl.SSLEOFError as e: logger.warning( 'ssl.SSLEOFError {0} when wrapping client for upstream: {1}'.format( str(e), self.upstream.addr[0], ), ) - return True + do_close = True except ssl.SSLError as e: if e.reason in ('TLSV1_ALERT_UNKNOWN_CA', 'UNSUPPORTED_PROTOCOL'): logger.warning( @@ -757,21 +764,26 @@ def wrap_client(self) -> bool: self.upstream.addr[0], ), exc_info=e, ) - return True + do_close = True except BrokenPipeError: logger.error( 'BrokenPipeError when wrapping client for upstream: {0}'.format( self.upstream.addr[0], ), ) - return True + do_close = True except OSError as e: logger.exception( 'OSError when wrapping client for upstream: {0}'.format( self.upstream.addr[0], ), exc_info=e, ) - return True + do_close = True + finally: + if do_close: + raise HttpProtocolException( + 'Exception when wrapping client for interception', + ) logger.debug('TLS intercepting using %s', generated_cert) return False
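Below is a rough, self-contained sketch (not part of this change set) of the kind of work class that `Acceptor` and `Threadless` drive after this refactor. It only illustrates the lifecycle referenced in the docstrings above: the instance wraps an accepted client connection, reports the descriptors it is interested in via get_events(), is driven through handle_events() until it signals teardown, and exposes run() for the threaded mode. The EchoWork name, the fileno-keyed event map, and the exact method signatures are illustrative assumptions; the real `Work` protocol in proxy.py may differ.

import selectors
import socket
from typing import Dict, List, Tuple


class EchoWork:
    """Echoes back whatever the client sends. Hypothetical example only."""

    def __init__(self, conn: socket.socket, addr: Tuple[str, int]) -> None:
        self.conn = conn
        self.addr = addr
        self.buffer = b''

    def get_events(self) -> Dict[int, int]:
        # Watch the client socket for reads; also for writes
        # whenever we have pending data to flush back to the client.
        events = selectors.EVENT_READ
        if self.buffer:
            events |= selectors.EVENT_WRITE
        return {self.conn.fileno(): events}

    def handle_events(self, readables: List[int], writables: List[int]) -> bool:
        # Returning True signals teardown, mirroring BaseTcpServerHandler.
        if self.conn.fileno() in readables:
            data = self.conn.recv(1024)
            if not data:
                return True  # client closed the connection
            self.buffer += data
        if self.conn.fileno() in writables and self.buffer:
            sent = self.conn.send(self.buffer)
            self.buffer = self.buffer[sent:]
        return False

    def run(self) -> None:
        # Threaded mode: drive a private event loop for this single
        # connection, similar in spirit to what Threadless does for
        # many work instances at once.
        with selectors.DefaultSelector() as sel:
            try:
                while True:
                    sel.register(self.conn, self.get_events()[self.conn.fileno()])
                    ready = sel.select(timeout=1)
                    sel.unregister(self.conn)
                    readables = [
                        key.fileobj.fileno() for key, mask in ready
                        if mask & selectors.EVENT_READ
                    ]
                    writables = [
                        key.fileobj.fileno() for key, mask in ready
                        if mask & selectors.EVENT_WRITE
                    ]
                    if self.handle_events(readables, writables):
                        break
            finally:
                self.conn.close()

A threaded deployment would call EchoWork(conn, addr).run() from a new thread, as _start_threaded_work does with the real work class, whereas in threadless mode a single shared event loop calls get_events()/handle_events() for many such instances, which is why those methods must never block.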