From afdef1bc585827d462ec1cfb848d259e54780123 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Cardona?= Date: Sat, 13 Oct 2018 11:41:51 +0200 Subject: [PATCH] bpo-34971: add support for TLS sessions from asyncio --- Lib/asyncio/base_events.py | 18 +++++-- Lib/asyncio/proactor_events.py | 6 ++- Lib/asyncio/selector_events.py | 6 ++- Lib/asyncio/sslproto.py | 15 ++++-- Lib/test/test_asyncio/test_base_events.py | 27 +++++++--- Lib/test/test_asyncio/test_events.py | 63 +++++++++++++++++++++++ 6 files changed, 115 insertions(+), 20 deletions(-) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 3726c556d4f09d..17ae6835e3bfd8 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -432,7 +432,8 @@ def _make_ssl_transport( *, server_side=False, server_hostname=None, extra=None, server=None, ssl_handshake_timeout=None, - call_connection_made=True): + call_connection_made=True, + ssl_session=None): """Create SSL transport.""" raise NotImplementedError @@ -866,7 +867,8 @@ async def create_connection( *, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_session=None): """Connect to a TCP server. Create a streaming transport connection to a given Internet host and @@ -901,6 +903,9 @@ async def create_connection( raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if ssl_session is not None and not ssl: + raise ValueError('ssl_session is only meaningful with ssl') + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -984,7 +989,8 @@ async def create_connection( transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_session=ssl_session) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -996,7 +1002,8 @@ async def create_connection( async def _create_connection_transport( self, sock, protocol_factory, ssl, server_hostname, server_side=False, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_session=None): sock.setblocking(False) @@ -1007,7 +1014,8 @@ async def _create_connection_transport( transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, server_side=server_side, server_hostname=server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_session=ssl_session) else: transport = self._make_socket_transport(sock, protocol, waiter) diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index ad23918802faad..f46692eee825a5 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -495,11 +495,13 @@ def _make_ssl_transport( self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_session=None): ssl_protocol = sslproto.SSLProtocol( self, protocol, sslcontext, waiter, server_side, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_session=ssl_session) _ProactorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 116c08d6ff7fdd..4924e88bb3d747 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -75,11 +75,13 @@ def _make_ssl_transport( self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, + ssl_session=None): ssl_protocol = sslproto.SSLProtocol( self, protocol, sslcontext, waiter, server_side, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout) + ssl_handshake_timeout=ssl_handshake_timeout, + ssl_session=ssl_session) _SelectorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 12fdb0d1c5ec10..b7d17b69c14079 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -53,7 +53,7 @@ class _SSLPipe(object): max_size = 256 * 1024 # Buffer size passed to read() - def __init__(self, context, server_side, server_hostname=None): + def __init__(self, context, server_side, server_hostname=None, session=None): """ The *context* argument specifies the ssl.SSLContext to use. @@ -67,6 +67,7 @@ def __init__(self, context, server_side, server_hostname=None): self._context = context self._server_side = server_side self._server_hostname = server_hostname + self._session = session self._state = _UNWRAPPED self._incoming = ssl.MemoryBIO() self._outgoing = ssl.MemoryBIO() @@ -117,7 +118,8 @@ def do_handshake(self, callback=None): self._sslobj = self._context.wrap_bio( self._incoming, self._outgoing, server_side=self._server_side, - server_hostname=self._server_hostname) + server_hostname=self._server_hostname, + session=self._session) self._state = _DO_HANDSHAKE self._handshake_cb = callback ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) @@ -412,7 +414,8 @@ class SSLProtocol(protocols.Protocol): def __init__(self, loop, app_protocol, sslcontext, waiter, server_side=False, server_hostname=None, call_connection_made=True, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + ssl_session=None): if ssl is None: raise RuntimeError('stdlib ssl module not available') @@ -433,9 +436,10 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, else: self._server_hostname = None self._sslcontext = sslcontext + self._ssl_session = ssl_session # SSL-specific extra info. More info are set when the handshake # completes. - self._extra = dict(sslcontext=sslcontext) + self._extra = dict(sslcontext=sslcontext, ssl_session=ssl_session) # App data write buffering self._write_backlog = collections.deque() @@ -478,7 +482,8 @@ def connection_made(self, transport): self._transport = transport self._sslpipe = _SSLPipe(self._sslcontext, self._server_side, - self._server_hostname) + self._server_hostname, + self._ssl_session) self._start_handshake() def connection_lost(self, exc): diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 6d544d1eda8635..da8b481b51d2b2 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -1418,44 +1418,51 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport ANY = mock.ANY handshake_timeout = object() + session = object() # First try the default server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_session=session) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='python.org', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_session=session) # Next try an explicit server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, server_hostname='perl.com', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_session=session) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='perl.com', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_session=session) # Finally try an explicit empty server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, server_hostname='', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_session=session) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='', - ssl_handshake_timeout=handshake_timeout) + ssl_handshake_timeout=handshake_timeout, + ssl_session=session) def test_create_connection_no_ssl_server_hostname_errors(self): # When not using ssl, server_hostname must be None. @@ -1486,6 +1493,14 @@ def test_create_connection_ssl_timeout_for_plain_socket(self): 'ssl_handshake_timeout is only meaningful with ssl'): self.loop.run_until_complete(coro) + def test_create_connection_ssl_session_for_plain_socket(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, ssl_session=object()) + with self.assertRaisesRegex( + ValueError, + 'ssl_session is only meaningful with ssl'): + self.loop.run_until_complete(coro) + def test_create_server_empty_host(self): # if host is empty string use None instead host = object() diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index b76cfb75cce26a..65067c6c9e719a 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -33,6 +33,7 @@ from asyncio import proactor_events from asyncio import selector_events from test.test_asyncio import utils as test_utils +from test.ssl_servers import make_https_server from test import support @@ -614,6 +615,68 @@ def test_create_ssl_connection(self): self._test_create_ssl_connection(httpd, create_connection, peername=httpd.address) + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection_with_session(self): + server_context = test_utils.simple_server_sslcontext() + server = make_https_server(self, context=server_context) + + client_context = test_utils.simple_client_sslcontext() + # TODO: sessions aren't compatible with TLSv1.3 yet + client_context.options |= ssl.OP_NO_TLSv1_3 + + def new_conn(*, session=None): + create_connection = functools.partial( + self.loop.create_connection, + lambda: MyProto(loop=self.loop), + 'localhost', server.port) + conn_fut = create_connection(ssl=client_context, ssl_session=session) + tr, pr = self.loop.run_until_complete(conn_fut) + self.loop.run_until_complete(pr.done) + sslobj = tr.get_extra_info('ssl_object') + stats = { + 'session': sslobj.session, + 'session_reused': sslobj.session_reused, + } + tr.close() + return stats + + # first connection without session + stats = new_conn() + session = stats['session'] + self.assertTrue(session.id) + self.assertGreater(session.time, 0) + self.assertGreater(session.timeout, 0) + self.assertTrue(session.has_ticket) + if ssl.OPENSSL_VERSION_INFO > (1, 0, 1): + self.assertGreater(session.ticket_lifetime_hint, 0) + self.assertFalse(stats['session_reused']) + + # reuse session + stats = new_conn(session=session) + self.assertTrue(stats['session_reused']) + session2 = stats['session'] + self.assertEqual(session2.id, session.id) + self.assertEqual(session2, session) + self.assertIsNot(session2, session) + self.assertGreaterEqual(session2.time, session.time) + self.assertGreaterEqual(session2.timeout, session.timeout) + + # another one without session + stats = new_conn() + self.assertFalse(stats['session_reused']) + session3 = stats['session'] + self.assertNotEqual(session3.id, session.id) + self.assertNotEqual(session3, session) + + # reuse session again + stats = new_conn(session=session) + self.assertTrue(stats['session_reused']) + session4 = stats['session'] + self.assertEqual(session4.id, session.id) + self.assertEqual(session4, session) + self.assertGreaterEqual(session4.time, session.time) + self.assertGreaterEqual(session4.timeout, session.timeout) + @support.skip_unless_bind_unix_socket @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_unix_connection(self):