Skip to content

Commit

Permalink
bpo-44011: New asyncio ssl implementation (#17975)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov authored May 2, 2021
1 parent c96cc08 commit 5fb06ed
Show file tree
Hide file tree
Showing 12 changed files with 2,474 additions and 527 deletions.
43 changes: 34 additions & 9 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ async def restore(self):
class Server(events.AbstractServer):

def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
ssl_handshake_timeout):
ssl_handshake_timeout, ssl_shutdown_timeout=None):
self._loop = loop
self._sockets = sockets
self._active_count = 0
Expand All @@ -282,6 +282,7 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
self._backlog = backlog
self._ssl_context = ssl_context
self._ssl_handshake_timeout = ssl_handshake_timeout
self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False
self._serving_forever_fut = None

Expand Down Expand Up @@ -313,7 +314,8 @@ def _start_serving(self):
sock.listen(self._backlog)
self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context,
self, self._backlog, self._ssl_handshake_timeout)
self, self._backlog, self._ssl_handshake_timeout,
self._ssl_shutdown_timeout)

def get_loop(self):
return self._loop
Expand Down Expand Up @@ -467,6 +469,7 @@ def _make_ssl_transport(
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
call_connection_made=True):
"""Create SSL transport."""
raise NotImplementedError
Expand Down Expand Up @@ -969,6 +972,7 @@ async def create_connection(
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
"""Connect to a TCP server.
Expand Down Expand Up @@ -1004,6 +1008,10 @@ async def create_connection(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
Expand Down Expand Up @@ -1079,7 +1087,8 @@ async def create_connection(

transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
Expand All @@ -1091,7 +1100,8 @@ async def create_connection(
async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):

sock.setblocking(False)

Expand All @@ -1102,7 +1112,8 @@ async def _create_connection_transport(
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(sock, protocol, waiter)

Expand Down Expand Up @@ -1193,7 +1204,8 @@ async def _sendfile_fallback(self, transp, file, offset, count):
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
"""Upgrade transport to TLS.
Return a new transport that *protocol* should start using
Expand All @@ -1216,6 +1228,7 @@ async def start_tls(self, transport, protocol, sslcontext, *,
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)

# Pause early so that "ssl_protocol.data_received()" doesn't
Expand Down Expand Up @@ -1414,6 +1427,7 @@ async def create_server(
reuse_address=None,
reuse_port=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""Create a TCP server.
Expand All @@ -1437,6 +1451,10 @@ async def create_server(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and ssl is None:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

if host is not None or port is not None:
if sock is not None:
raise ValueError(
Expand Down Expand Up @@ -1509,7 +1527,8 @@ async def create_server(
sock.setblocking(False)

server = Server(self, sockets, protocol_factory,
ssl, backlog, ssl_handshake_timeout)
ssl, backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
Expand All @@ -1523,17 +1542,23 @@ async def create_server(
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
if sock.type != socket.SOCK_STREAM:
raise ValueError(f'A Stream Socket was expected, got {sock!r}')

if ssl_handshake_timeout is not None and not ssl:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')

if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')

transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
Expand Down
7 changes: 7 additions & 0 deletions Lib/asyncio/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@
# The default timeout matches that of Nginx.
SSL_HANDSHAKE_TIMEOUT = 60.0

# Number of seconds to wait for SSL shutdown to complete
# The default timeout mimics lingering_time
SSL_SHUTDOWN_TIMEOUT = 30.0

# Used in sendfile fallback code. We use fallback for platforms
# that don't support sendfile, or for TLS connections.
SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256

FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB
FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB

# The enum should be here to break circular dependencies between
# base_events and sslproto
class _SendfileMode(enum.Enum):
Expand Down
19 changes: 16 additions & 3 deletions Lib/asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ async def create_connection(
flags=0, sock=None, local_addr=None,
server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
raise NotImplementedError

Expand All @@ -313,6 +314,7 @@ async def create_server(
flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a TCP server bound to host and port.
Expand Down Expand Up @@ -353,6 +355,10 @@ async def create_server(
will wait for completion of the SSL handshake before aborting the
connection. Default is 60s.
ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for completion of the SSL shutdown procedure
before aborting the connection. Default is 30s.
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
Expand All @@ -371,7 +377,8 @@ async def sendfile(self, transport, file, offset=0, count=None,
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
"""Upgrade a transport to TLS.
Return a new transport that *protocol* should start using
Expand All @@ -383,13 +390,15 @@ async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
raise NotImplementedError

async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a UNIX Domain Socket server.
Expand All @@ -411,6 +420,9 @@ async def create_unix_server(
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 60s).
ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for the SSL shutdown to finish (defaults to 30s).
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
Expand All @@ -421,7 +433,8 @@ async def create_unix_server(
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
"""Handle an accepted connection.
This is used by servers that accept connections outside of
Expand Down
12 changes: 8 additions & 4 deletions Lib/asyncio/proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,11 +642,13 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
_ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
Expand Down Expand Up @@ -812,7 +814,8 @@ def _write_to_self(self):

def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=None):
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):

def loop(f=None):
try:
Expand All @@ -826,7 +829,8 @@ def loop(f=None):
self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
self._make_socket_transport(
conn, protocol,
Expand Down
31 changes: 20 additions & 11 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,15 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout
)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
Expand Down Expand Up @@ -146,15 +150,17 @@ def _write_to_self(self):

def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog,
ssl_handshake_timeout)
ssl_handshake_timeout, ssl_shutdown_timeout)

def _accept_connection(
self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
# This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
Expand Down Expand Up @@ -185,20 +191,22 @@ def _accept_connection(
self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving,
protocol_factory, sock, sslcontext, server,
backlog, ssl_handshake_timeout)
backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
else:
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server,
ssl_handshake_timeout)
ssl_handshake_timeout, ssl_shutdown_timeout)
self.create_task(accept)

async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
protocol = None
transport = None
try:
Expand All @@ -208,7 +216,8 @@ async def _accept_connection2(
transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout)
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
Expand Down
Loading

0 comments on commit 5fb06ed

Please sign in to comment.