From a2895e42f1b3f3896f7c33f1f6ebbf6a33142755 Mon Sep 17 00:00:00 2001 From: Josh Snyder Date: Mon, 19 Apr 2021 15:53:55 -0700 Subject: [PATCH] bpo-37355: For nonblocking sockets, call SSL_read in a loop Continue looping until data is exhausted, and only then reacquire the GIL. This makes it possible to perform multi-threaded TLS downloads without saturating the GIL. On a test workload performing HTTPS download with 32 threads pinned to 16 cores, this produces a 4x speedup. before after wall clock time (s) : 29.637 7.116 user time (s) : 8.793 12.584 system time (s) : 105.118 30.010 voluntary switches : 1,653,065 248,484 speed (MB/s) : 4733 19712 --- Lib/test/test_ssl.py | 59 +++++++++++++++++-- .../2021-04-19-15-53-03.bpo-37355.3pie1n.rst | 3 + Modules/_ssl.c | 25 ++++++-- 3 files changed, 75 insertions(+), 12 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 543d34a5469333..a3dcf9294f3e01 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -2349,6 +2349,50 @@ def test_bio_read_write_data(self): self.assertEqual(buf, b'foo\n') self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + def test_bulk_nonblocking_read(self): + # 65536 bytes divide up into 4 TLS records (16 KB each) + # In nonblocking mode, we should be able to read all four in a single + # drop of the GIL. + size = 65536 + trips = [] + + client_context, server_context, hostname = testing_context() + server = ThreadedEchoServer(context=server_context, chatty=False, + buffer_size=size) + with server: + sock = socket.create_connection((HOST, server.port)) + sock.settimeout(0.0) + s = client_context.wrap_socket(sock, server_hostname=hostname, + do_handshake_on_connect=False) + + with s: + while True: + try: + s.do_handshake() + break + except ssl.SSLWantReadError: + select.select([s], [], []) + except ssl.SSLWantWriteError: + select.select([], [s], []) + + s.send(b'\x00' * size) + + select.select([s], [], []) + + while size > 0: + try: + count = len(s.recv(size)) + except ssl.SSLWantReadError: + select.select([s], [], []) + # Give the sender some more time to complete sending. + time.sleep(0.01) + else: + if count > 16384: + return + size -= count + + raise AssertionError("All TLS reads were smaller than 16KB") + @support.requires_resource('network') class NetworkedTests(unittest.TestCase): @@ -2408,7 +2452,7 @@ class ConnectionHandler(threading.Thread): with and without the SSL wrapper around the socket connection, so that we can test the STARTTLS functionality.""" - def __init__(self, server, connsock, addr): + def __init__(self, server, connsock, addr, buffer_size): self.server = server self.running = False self.sock = connsock @@ -2417,6 +2461,7 @@ def __init__(self, server, connsock, addr): self.sslconn = None threading.Thread.__init__(self) self.daemon = True + self.buffer_size = buffer_size def wrap_conn(self): try: @@ -2482,9 +2527,9 @@ def wrap_conn(self): def read(self): if self.sslconn: - return self.sslconn.read() + return self.sslconn.read(self.buffer_size) else: - return self.sock.recv(1024) + return self.sock.recv(self.buffer_size) def write(self, bytes): if self.sslconn: @@ -2602,8 +2647,8 @@ def run(self): def __init__(self, certificate=None, ssl_version=None, certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, - alpn_protocols=None, - ciphers=None, context=None): + alpn_protocols=None, ciphers=None, context=None, + buffer_size=1024): if context: self.context = context else: @@ -2632,6 +2677,7 @@ def __init__(self, certificate=None, ssl_version=None, self.conn_errors = [] threading.Thread.__init__(self) self.daemon = True + self.buffer_size = buffer_size def __enter__(self): self.start(threading.Event()) @@ -2659,7 +2705,8 @@ def run(self): if support.verbose and self.chatty: sys.stdout.write(' server: new connection from ' + repr(connaddr) + '\n') - handler = self.ConnectionHandler(self, newconn, connaddr) + handler = self.ConnectionHandler(self, newconn, connaddr, + self.buffer_size) handler.start() handler.join() except TimeoutError as e: diff --git a/Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst b/Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst new file mode 100644 index 00000000000000..be6ad12d4711b1 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst @@ -0,0 +1,3 @@ +When reading from a nonblocking TLS socket, drop the GIL once to read up to +the entire buffer. Previously we would read at most one TLS record (16 KB). +Patch by Josh Snyder. diff --git a/Modules/_ssl.c b/Modules/_ssl.c index e67ab42050b26c..65e092c759d940 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -2444,10 +2444,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, PyObject *dest = NULL; char *mem; size_t count = 0; + size_t got = 0; int retval; int sockstate; _PySSLError err; - int nonblocking; + int nonblocking = 0; PySocketSockObject *sock = GET_SOCKET(self); _PyTime_t timeout, deadline = 0; int has_timeout; @@ -2507,11 +2508,23 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, do { PySSL_BEGIN_ALLOW_THREADS - retval = SSL_read_ex(self->ssl, mem, (size_t)len, &count); + do { + retval = SSL_read_ex(self->ssl, mem + got, len, &count); + if(retval <= 0) { + break; + } + + got += count; + len -= count; + } while(nonblocking && len > 0); err = _PySSL_errno(retval == 0, self->ssl, retval); PySSL_END_ALLOW_THREADS self->err = err; + if(got > 0) { + break; + } + if (PyErr_CheckSignals()) goto error; @@ -2526,7 +2539,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, } else if (err.ssl == SSL_ERROR_ZERO_RETURN && SSL_get_shutdown(self->ssl) == SSL_RECEIVED_SHUTDOWN) { - count = 0; + got = 0; goto done; } else @@ -2542,7 +2555,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, } while (err.ssl == SSL_ERROR_WANT_READ || err.ssl == SSL_ERROR_WANT_WRITE); - if (retval == 0) { + if (got == 0) { PySSL_SetError(self, retval, __FILE__, __LINE__); goto error; } @@ -2552,11 +2565,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, done: Py_XDECREF(sock); if (!group_right_1) { - _PyBytes_Resize(&dest, count); + _PyBytes_Resize(&dest, got); return dest; } else { - return PyLong_FromSize_t(count); + return PyLong_FromSize_t(got); } error: