From 1ba375b0b09b2951fe5c189750a5fa7a0e733b3a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 11 May 2019 19:46:40 +0300 Subject: [PATCH 01/82] Merge StreamReader and StreamWriter --- Lib/asyncio/streams.py | 184 +++++++++++++++++++++++++---------------- 1 file changed, 111 insertions(+), 73 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index d9a9f5e72d3b79..be879061c9630a 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -2,6 +2,7 @@ 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', 'open_connection', 'start_server') +import enum import socket import sys import warnings @@ -22,6 +23,11 @@ _DEFAULT_LIMIT = 2 ** 16 # 64 KiB +class StreamKind(enum.IntEnum): + READ = 1 + WRITE = 2 + + async def open_connection(host=None, port=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. @@ -329,7 +335,7 @@ def __del__(self): closed.exception() -class StreamWriter: +class Stream: """Wraps a Transport. This exposes write(), writelines(), [can_]write_eof(), @@ -339,26 +345,66 @@ class StreamWriter: directly. """ - def __init__(self, transport, protocol, reader, loop, - *, _asyncio_internal=False): + _source_traceback = None + + def __init__(self, kind, *, + transport=None, + protocol=None, + loop=None, + limit=_DEFAULT_LIMIT, + _asyncio_internal=False): if not _asyncio_internal: warnings.warn(f"{self.__class__} should be instaniated " "by asyncio internals only, " "please avoid its creation from user code", DeprecationWarning) + self._kind = kind self._transport = transport self._protocol = protocol - # drain() expects that the reader has an exception() method - assert reader is None or isinstance(reader, StreamReader) - self._reader = reader - self._loop = loop + + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + + if limit <= 0: + raise ValueError('Limit cannot be <= 0') + + self._limit = limit + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._buffer = bytearray() + self._eof = False # Whether we're done. + self._waiter = None # A future used by _wait_for_data() + self._exception = None + self._transport = None + self._paused = False self._complete_fut = self._loop.create_future() self._complete_fut.set_result(None) + if self._loop.get_debug(): + self._source_traceback = format_helpers.extract_stack( + sys._getframe(1)) + def __repr__(self): - info = [self.__class__.__name__, f'transport={self._transport!r}'] - if self._reader is not None: - info.append(f'reader={self._reader!r}') + info = [self.__class__.__name__] + info.append(str(self._kind)) + if self._buffer: + info.append(f'{len(self._buffer)} bytes') + if self._eof: + info.append('eof') + if self._limit != _DEFAULT_LIMIT: + info.append(f'limit={self._limit}') + if self._waiter: + info.append(f'waiter={self._waiter!r}') + if self._exception: + info.append(f'exception={self._exception!r}') + if self._transport: + info.append(f'transport={self._transport!r}') + if self._paused: + info.append('paused') + if self._transport is not None: + info.append(f'transport={self._transport!r}') return '<{}>'.format(' '.join(info)) @property @@ -366,10 +412,14 @@ def transport(self): return self._transport def write(self, data): + if not self._kind & StreamKind.WRITE: + raise RuntimeError("The stream is read-only") self._transport.write(data) return self._fast_drain() def writelines(self, data): + if not self._kind & StreamKind.WRITE: + raise RuntimeError("The stream is read-only") self._transport.writelines(data) return self._fast_drain() @@ -377,13 +427,11 @@ def _fast_drain(self): # The helper tries to use fast-path to return already existing complete future # object if underlying transport is not paused and actual waiting for writing # resume is not needed - if self._reader is not None: - # this branch will be simplified after merging reader with writer - exc = self._reader.exception() - if exc is not None: - fut = self._loop.create_future() - fut.set_exception(exc) - return fut + exc = self.exception() + if exc is not None: + fut = self._loop.create_future() + fut.set_exception(exc) + return fut if not self._transport.is_closing(): if self._protocol._connection_lost: fut = self._loop.create_future() @@ -396,9 +444,13 @@ def _fast_drain(self): return self._loop.create_task(self.drain()) def write_eof(self): + if not self._kind & StreamKind.WRITE: + raise RuntimeError("The stream is read-only") return self._transport.write_eof() def can_write_eof(self): + if not self._kind & StreamKind.WRITE: + return False return self._transport.can_write_eof() def close(self): @@ -422,10 +474,11 @@ async def drain(self): w.write(data) await w.drain() """ - if self._reader is not None: - exc = self._reader.exception() - if exc is not None: - raise exc + if not self._kind & StreamKind.WRITE: + raise RuntimeError("The stream is read-only") + exc = self.exception() + if exc is not None: + raise exc if self._transport.is_closing(): # Wait for protocol.connection_lost() call # Raise connection closing error if any, @@ -435,58 +488,6 @@ async def drain(self): raise ConnectionResetError('Connection lost') await self._protocol._drain_helper() - -class StreamReader: - - _source_traceback = None - - def __init__(self, limit=_DEFAULT_LIMIT, loop=None, - *, _asyncio_internal=False): - if not _asyncio_internal: - warnings.warn(f"{self.__class__} should be instaniated " - "by asyncio internals only, " - "please avoid its creation from user code", - DeprecationWarning) - - # The line length limit is a security feature; - # it also doubles as half the buffer limit. - - if limit <= 0: - raise ValueError('Limit cannot be <= 0') - - self._limit = limit - if loop is None: - self._loop = events.get_event_loop() - else: - self._loop = loop - self._buffer = bytearray() - self._eof = False # Whether we're done. - self._waiter = None # A future used by _wait_for_data() - self._exception = None - self._transport = None - self._paused = False - if self._loop.get_debug(): - self._source_traceback = format_helpers.extract_stack( - sys._getframe(1)) - - def __repr__(self): - info = ['StreamReader'] - if self._buffer: - info.append(f'{len(self._buffer)} bytes') - if self._eof: - info.append('eof') - if self._limit != _DEFAULT_LIMIT: - info.append(f'limit={self._limit}') - if self._waiter: - info.append(f'waiter={self._waiter!r}') - if self._exception: - info.append(f'exception={self._exception!r}') - if self._transport: - info.append(f'transport={self._transport!r}') - if self._paused: - info.append('paused') - return '<{}>'.format(' '.join(info)) - def exception(self): return self._exception @@ -517,14 +518,20 @@ def _maybe_resume_transport(self): self._transport.resume_reading() def feed_eof(self): + if not self._kind & StreamKind.READ: + raise RuntimeError("The stream is write-only") self._eof = True self._wakeup_waiter() def at_eof(self): """Return True if the buffer is empty and 'feed_eof' was called.""" + if not self._kind & StreamKind.READ: + raise RuntimeError("The stream is write-only") return self._eof and not self._buffer def feed_data(self, data): + if not self._kind & StreamKind.READ: + raise RuntimeError("The stream is write-only") assert not self._eof, 'feed_data after feed_eof' if not data: @@ -590,6 +597,8 @@ async def readline(self): If stream was paused, this function will automatically resume it if needed. """ + if not self._kind & StreamKind.READ: + raise RuntimeError("The stream is write-only") sep = b'\n' seplen = len(sep) try: @@ -625,6 +634,8 @@ async def readuntil(self, separator=b'\n'): LimitOverrunError exception will be raised, and the data will be left in the internal buffer, so it can be read again. """ + if not self._kind & StreamKind.READ: + raise RuntimeError("The stream is write-only") seplen = len(separator) if seplen == 0: raise ValueError('Separator should be at least one-byte string') @@ -716,6 +727,8 @@ async def read(self, n=-1): If stream was paused, this function will automatically resume it if needed. """ + if not self._kind & StreamKind.READ: + raise RuntimeError("The stream is write-only") if self._exception is not None: raise self._exception @@ -761,6 +774,8 @@ async def readexactly(self, n): If stream was paused, this function will automatically resume it if needed. """ + if not self._kind & StreamKind.READ: + raise RuntimeError("The stream is write-only") if n < 0: raise ValueError('readexactly size can not be less than zero') @@ -788,6 +803,8 @@ async def readexactly(self, n): return data def __aiter__(self): + if not self._kind & StreamKind.READ: + raise RuntimeError("The stream is write-only") return self async def __anext__(self): @@ -795,3 +812,24 @@ async def __anext__(self): if val == b'': raise StopAsyncIteration return val + + +class StreamWriter(Stream): + def __init__(self, transport, protocol, reader, loop, + *, _asyncio_internal=False): + super().__init__(kind=StreamKind.WRITE, + transport=transpotr, + protocol=protocol, + loop=loop, + _asyncio_internal=_asyncio_internal + ) + + +class StreamReader(Stream): + def __init__(self, limit=_DEFAULT_LIMIT, loop=None, + *, _asyncio_internal=False): + super().__init__(kind=StreamKind.READ, + limit=limit, + loop=loop, + _asyncio_internal=_asyncio_internal, + ) From 1af22383494ade624a1e2a4dcfda1bd7ae0bae45 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 11 May 2019 19:46:59 +0300 Subject: [PATCH 02/82] Merge StreamReader and StreamWriter --- Lib/asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index be879061c9630a..a8671c40478c0d 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -818,7 +818,7 @@ class StreamWriter(Stream): def __init__(self, transport, protocol, reader, loop, *, _asyncio_internal=False): super().__init__(kind=StreamKind.WRITE, - transport=transpotr, + transport=transport, protocol=protocol, loop=loop, _asyncio_internal=_asyncio_internal From d0543a522a691cd239bdb4a25750c10d185484eb Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 11 May 2019 20:00:38 +0300 Subject: [PATCH 03/82] Work on --- Lib/asyncio/streams.py | 63 ++++++++++++++------------- Lib/test/test_asyncio/test_streams.py | 17 ++++---- 2 files changed, 41 insertions(+), 39 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index a8671c40478c0d..72f785f6f4934a 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -23,9 +23,24 @@ _DEFAULT_LIMIT = 2 ** 16 # 64 KiB -class StreamKind(enum.IntEnum): - READ = 1 - WRITE = 2 +class StreamKind(enum.Enum): + READ = "read" + WRITE = "write" + READWRITE = "readwrite" + + def is_read(self): + return self in (self.READ, self.READWRITE) + + def is_write(self): + return self in (self.WRITE, self.READWRITE) + + def check_read(self): + if not self.is_read(): + raise RuntimeError("The stream is read-only") + + def check_write(self): + if not self.is_write(): + raise RuntimeError("The stream is write-only") async def open_connection(host=None, port=None, *, @@ -388,7 +403,7 @@ def __init__(self, kind, *, def __repr__(self): info = [self.__class__.__name__] - info.append(str(self._kind)) + info.append(f'kind={self._kind}') if self._buffer: info.append(f'{len(self._buffer)} bytes') if self._eof: @@ -403,8 +418,6 @@ def __repr__(self): info.append(f'transport={self._transport!r}') if self._paused: info.append('paused') - if self._transport is not None: - info.append(f'transport={self._transport!r}') return '<{}>'.format(' '.join(info)) @property @@ -412,14 +425,12 @@ def transport(self): return self._transport def write(self, data): - if not self._kind & StreamKind.WRITE: - raise RuntimeError("The stream is read-only") + self._kind.check_write() self._transport.write(data) return self._fast_drain() def writelines(self, data): - if not self._kind & StreamKind.WRITE: - raise RuntimeError("The stream is read-only") + self._kind.check_write() self._transport.writelines(data) return self._fast_drain() @@ -444,12 +455,11 @@ def _fast_drain(self): return self._loop.create_task(self.drain()) def write_eof(self): - if not self._kind & StreamKind.WRITE: - raise RuntimeError("The stream is read-only") + self._kind.check_write() return self._transport.write_eof() def can_write_eof(self): - if not self._kind & StreamKind.WRITE: + if not self._kind.is_write(): return False return self._transport.can_write_eof() @@ -474,8 +484,7 @@ async def drain(self): w.write(data) await w.drain() """ - if not self._kind & StreamKind.WRITE: - raise RuntimeError("The stream is read-only") + self._kind.check_read() exc = self.exception() if exc is not None: raise exc @@ -518,20 +527,17 @@ def _maybe_resume_transport(self): self._transport.resume_reading() def feed_eof(self): - if not self._kind & StreamKind.READ: - raise RuntimeError("The stream is write-only") + self._kind.check_read() self._eof = True self._wakeup_waiter() def at_eof(self): """Return True if the buffer is empty and 'feed_eof' was called.""" - if not self._kind & StreamKind.READ: - raise RuntimeError("The stream is write-only") + self._kind.check_read() return self._eof and not self._buffer def feed_data(self, data): - if not self._kind & StreamKind.READ: - raise RuntimeError("The stream is write-only") + self._kind.check_read() assert not self._eof, 'feed_data after feed_eof' if not data: @@ -597,8 +603,7 @@ async def readline(self): If stream was paused, this function will automatically resume it if needed. """ - if not self._kind & StreamKind.READ: - raise RuntimeError("The stream is write-only") + self._kind.check_read() sep = b'\n' seplen = len(sep) try: @@ -634,8 +639,7 @@ async def readuntil(self, separator=b'\n'): LimitOverrunError exception will be raised, and the data will be left in the internal buffer, so it can be read again. """ - if not self._kind & StreamKind.READ: - raise RuntimeError("The stream is write-only") + self._kind.check_read() seplen = len(separator) if seplen == 0: raise ValueError('Separator should be at least one-byte string') @@ -727,8 +731,7 @@ async def read(self, n=-1): If stream was paused, this function will automatically resume it if needed. """ - if not self._kind & StreamKind.READ: - raise RuntimeError("The stream is write-only") + self._kind.check_read() if self._exception is not None: raise self._exception @@ -774,8 +777,7 @@ async def readexactly(self, n): If stream was paused, this function will automatically resume it if needed. """ - if not self._kind & StreamKind.READ: - raise RuntimeError("The stream is write-only") + self._kind.check_read() if n < 0: raise ValueError('readexactly size can not be less than zero') @@ -803,8 +805,7 @@ async def readexactly(self, n): return data def __aiter__(self): - if not self._kind & StreamKind.READ: - raise RuntimeError("The stream is write-only") + self._kind.check_read() return self async def __anext__(self): diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index bf93f30e1aafb6..7f1fd142096e8e 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -887,31 +887,31 @@ async def client(host, port): def test___repr__(self): stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__nondefault_limit(self): stream = asyncio.StreamReader(loop=self.loop, limit=123, _asyncio_internal=True) - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__eof(self): stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) stream.feed_eof() - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__data(self): stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) stream.feed_data(b'data') - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__exception(self): stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) exc = RuntimeError() stream.set_exception(exc) - self.assertEqual("", + self.assertEqual("", repr(stream)) def test___repr__waiter(self): @@ -920,11 +920,11 @@ def test___repr__waiter(self): stream._waiter = asyncio.Future(loop=self.loop) self.assertRegex( repr(stream), - r">") + r">") stream._waiter.set_result(None) self.loop.run_until_complete(stream._waiter) stream._waiter = None - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__transport(self): stream = asyncio.StreamReader(loop=self.loop, @@ -932,7 +932,8 @@ def test___repr__transport(self): stream._transport = mock.Mock() stream._transport.__repr__ = mock.Mock() stream._transport.__repr__.return_value = "" - self.assertEqual(">", repr(stream)) + self.assertEqual(">", + repr(stream)) def test_IncompleteReadError_pickleable(self): e = asyncio.IncompleteReadError(b'abc', 10) From b3cf11c275bd51b2e40cb400660af430336fa5af Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 11 May 2019 20:33:01 +0300 Subject: [PATCH 04/82] Work on --- Lib/asyncio/streams.py | 1 - 1 file changed, 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 72f785f6f4934a..30eff09e79e20c 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -392,7 +392,6 @@ def __init__(self, kind, *, self._eof = False # Whether we're done. self._waiter = None # A future used by _wait_for_data() self._exception = None - self._transport = None self._paused = False self._complete_fut = self._loop.create_future() self._complete_fut.set_result(None) From b059fd9032c3af8ad04e5f5d7dbd1f9343c033a5 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 11 May 2019 20:36:04 +0300 Subject: [PATCH 05/82] ... --- Lib/asyncio/streams.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 30eff09e79e20c..9686d53bcaf705 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -374,6 +374,9 @@ def __init__(self, kind, *, "please avoid its creation from user code", DeprecationWarning) self._kind = kind + if kind.is_write(): + assert transport is not None + assert protocol is not None self._transport = transport self._protocol = protocol From d5c33912645b52806e2110c9710bc7c02b83f6bc Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 13:37:34 +0300 Subject: [PATCH 06/82] ... --- Lib/asyncio/streams.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9686d53bcaf705..61cf99bf88cc5e 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -63,16 +63,15 @@ async def open_connection(host=None, port=None, *, really nothing special here except some convenience.) """ if loop is None: - loop = events.get_event_loop() - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) - protocol = StreamReaderProtocol(reader, loop=loop, - _asyncio_internal=True) + loop = events.get_running_loop() + stream = Stream(kind=StreamKind.READWRITE, + limit=limit, loop=loop, + _asyncio_internal=True) transport, _ = await loop.create_connection( - lambda: protocol, host, port, **kwds) - writer = StreamWriter(transport, protocol, reader, loop, - _asyncio_internal=True) - return reader, writer + lambda: StreamReaderProtocol(stream, loop=loop, + _asyncio_internal=True), + host, port, **kwds) + return stream, stream async def start_server(client_connected_cb, host=None, port=None, *, @@ -294,6 +293,7 @@ def connection_made(self, transport): reader = self._stream_reader if reader is not None: reader.set_transport(transport) + reader._protocol = self self._over_ssl = transport.get_extra_info('sslcontext') is not None if self._client_connected_cb is not None: self._stream_writer = StreamWriter(transport, self, @@ -374,9 +374,6 @@ def __init__(self, kind, *, "please avoid its creation from user code", DeprecationWarning) self._kind = kind - if kind.is_write(): - assert transport is not None - assert protocol is not None self._transport = transport self._protocol = protocol From abce39c3223f550b91fab71913dcbf491d67a3ad Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 14:24:32 +0300 Subject: [PATCH 07/82] ../.. --- Lib/asyncio/streams.py | 109 ++++++++++++++++++----------------------- 1 file changed, 49 insertions(+), 60 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 61cf99bf88cc5e..246b3b12a273d3 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -63,9 +63,10 @@ async def open_connection(host=None, port=None, *, really nothing special here except some convenience.) """ if loop is None: - loop = events.get_running_loop() + loop = events.get_event_loop() stream = Stream(kind=StreamKind.READWRITE, - limit=limit, loop=loop, + limit=limit, + loop=loop, _asyncio_internal=True) transport, _ = await loop.create_connection( lambda: StreamReaderProtocol(stream, loop=loop, @@ -101,9 +102,11 @@ async def start_server(client_connected_cb, host=None, port=None, *, loop = events.get_event_loop() def factory(): - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) - protocol = StreamReaderProtocol(reader, client_connected_cb, + stream = Stream(kind=StreamKind.READWRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + protocol = StreamReaderProtocol(stream, client_connected_cb, loop=loop, _asyncio_internal=True) return protocol @@ -119,15 +122,16 @@ async def open_unix_connection(path=None, *, """Similar to `open_connection` but works with UNIX Domain Sockets.""" if loop is None: loop = events.get_event_loop() - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) - protocol = StreamReaderProtocol(reader, loop=loop, - _asyncio_internal=True) + stream = Stream(kind=StreamKind.READWRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) transport, _ = await loop.create_unix_connection( - lambda: protocol, path, **kwds) - writer = StreamWriter(transport, protocol, reader, loop, - _asyncio_internal=True) - return reader, writer + lambda: StreamReaderProtocol(stream, + loop=loop, + _asyncio_internal=True), + path, **kwds) + return stream, stream async def start_unix_server(client_connected_cb, path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): @@ -136,9 +140,12 @@ async def start_unix_server(client_connected_cb, path=None, *, loop = events.get_event_loop() def factory(): - reader = StreamReader(limit=limit, loop=loop, - _asyncio_internal=True) - protocol = StreamReaderProtocol(reader, client_connected_cb, + stream = Stream(kind=StreamKind.READWRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + protocol = StreamReaderProtocol(stream, + client_connected_cb, loop=loop, _asyncio_internal=True) return protocol @@ -234,28 +241,19 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): _source_traceback = None - def __init__(self, stream_reader, client_connected_cb=None, loop=None, + def __init__(self, stream, client_connected_cb=None, loop=None, *, _asyncio_internal=False): super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) - if stream_reader is not None: - self._stream_reader_wr = weakref.ref(stream_reader, - self._on_reader_gc) - self._source_traceback = stream_reader._source_traceback - else: - self._stream_reader_wr = None - if client_connected_cb is not None: - # This is a stream created by the `create_server()` function. - # Keep a strong reference to the reader until a connection - # is established. - self._strong_reader = stream_reader + self._stream_wr = weakref.ref(stream, + self._on_gc) + self._source_traceback = stream._source_traceback self._reject_connection = False - self._stream_writer = None self._transport = None self._client_connected_cb = client_connected_cb self._over_ssl = False self._closed = self._loop.create_future() - def _on_reader_gc(self, wr): + def _on_gc(self, wr): transport = self._transport if transport is not None: # connection_made was called @@ -269,13 +267,12 @@ def _on_reader_gc(self, wr): transport.abort() else: self._reject_connection = True - self._stream_reader_wr = None + self._stream_wr = None @property - def _stream_reader(self): - if self._stream_reader_wr is None: - return None - return self._stream_reader_wr() + def _stream(self): + assert self._stream_wr is not None + return self._stream_wr() def connection_made(self, transport): if self._reject_connection: @@ -290,48 +287,40 @@ def connection_made(self, transport): transport.abort() return self._transport = transport - reader = self._stream_reader - if reader is not None: - reader.set_transport(transport) - reader._protocol = self + self._stream.set_transport(transport) + self._stream._protocol = self self._over_ssl = transport.get_extra_info('sslcontext') is not None if self._client_connected_cb is not None: - self._stream_writer = StreamWriter(transport, self, - reader, - self._loop, - _asyncio_internal=True) - res = self._client_connected_cb(reader, - self._stream_writer) + res = self._client_connected_cb(self._stream, + self._stream) if coroutines.iscoroutine(res): self._loop.create_task(res) - self._strong_reader = None def connection_lost(self, exc): - reader = self._stream_reader - if reader is not None: - if exc is None: - reader.feed_eof() - else: - reader.set_exception(exc) + if self._reject_connection: + return + if exc is None: + self._stream.feed_eof() + else: + self._stream.set_exception(exc) if not self._closed.done(): if exc is None: self._closed.set_result(None) else: self._closed.set_exception(exc) super().connection_lost(exc) - self._stream_reader_wr = None - self._stream_writer = None + self._stream_wr = None self._transport = None def data_received(self, data): - reader = self._stream_reader - if reader is not None: - reader.feed_data(data) + if self._reject_connection: + return + self._stream.feed_data(data) def eof_received(self): - reader = self._stream_reader - if reader is not None: - reader.feed_eof() + if self._reject_connection: + return + self._stream.feed_eof() if self._over_ssl: # Prevent a warning in SSLProtocol.eof_received: # "returning true from eof_received() From 0c8e8af247e92629a2b47d95db4f6e1b0059c698 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 14:57:46 +0300 Subject: [PATCH 08/82] Fix tests --- Lib/asyncio/streams.py | 47 +++++++++++++++++---------- Lib/test/test_asyncio/test_streams.py | 4 +-- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 246b3b12a273d3..c42e48f3dc027a 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -244,8 +244,13 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): def __init__(self, stream, client_connected_cb=None, loop=None, *, _asyncio_internal=False): super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) - self._stream_wr = weakref.ref(stream, - self._on_gc) + if client_connected_cb: + self._strong_stream = stream + self._stream_wr = None + else: + self._strong_stream = None + self._stream_wr = weakref.ref(stream, + self._on_gc) self._source_traceback = stream._source_traceback self._reject_connection = False self._transport = None @@ -271,7 +276,10 @@ def _on_gc(self, wr): @property def _stream(self): - assert self._stream_wr is not None + if self._strong_stream is not None: + return self._strong_stream + if self._stream_wr is None: + return None return self._stream_wr() def connection_made(self, transport): @@ -287,8 +295,11 @@ def connection_made(self, transport): transport.abort() return self._transport = transport - self._stream.set_transport(transport) - self._stream._protocol = self + stream = self._stream + if stream is None: + return + stream.set_transport(transport) + stream._protocol = self self._over_ssl = transport.get_extra_info('sslcontext') is not None if self._client_connected_cb is not None: res = self._client_connected_cb(self._stream, @@ -297,12 +308,12 @@ def connection_made(self, transport): self._loop.create_task(res) def connection_lost(self, exc): - if self._reject_connection: - return - if exc is None: - self._stream.feed_eof() - else: - self._stream.set_exception(exc) + stream = self._stream + if stream is not None: + if exc is None: + stream.feed_eof() + else: + stream.set_exception(exc) if not self._closed.done(): if exc is None: self._closed.set_result(None) @@ -313,14 +324,14 @@ def connection_lost(self, exc): self._transport = None def data_received(self, data): - if self._reject_connection: - return - self._stream.feed_data(data) + stream = self._stream + if stream is not None: + stream.feed_data(data) def eof_received(self): - if self._reject_connection: - return - self._stream.feed_eof() + stream = self._stream + if stream is not None: + stream.feed_eof() if self._over_ssl: # Prevent a warning in SSLProtocol.eof_received: # "returning true from eof_received() @@ -472,7 +483,7 @@ async def drain(self): w.write(data) await w.drain() """ - self._kind.check_read() + self._kind.check_write() exc = self.exception() if exc is not None: raise exc diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 7f1fd142096e8e..8fdcb61f0410d8 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1003,10 +1003,10 @@ def test_del_stream_before_sock_closing(self): # make a chance to close the socket test_utils.run_briefly(self.loop) - self.assertEqual(1, len(messages)) + self.assertEqual(1, len(messages), messages) self.assertEqual(sock.fileno(), -1) - self.assertEqual(1, len(messages)) + self.assertEqual(1, len(messages), messages) self.assertEqual('An open stream object is being garbage ' 'collected; call "stream.close()" explicitly.', messages[0]['message']) From e0a54d03043d2efc11891ba163d310b08d6485de Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 15:01:50 +0300 Subject: [PATCH 09/82] kind -> mode --- Lib/asyncio/streams.py | 52 +++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index c42e48f3dc027a..2eb12b4fb0a889 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -23,7 +23,7 @@ _DEFAULT_LIMIT = 2 ** 16 # 64 KiB -class StreamKind(enum.Enum): +class StreamMode(enum.Enum): READ = "read" WRITE = "write" READWRITE = "readwrite" @@ -64,7 +64,7 @@ async def open_connection(host=None, port=None, *, """ if loop is None: loop = events.get_event_loop() - stream = Stream(kind=StreamKind.READWRITE, + stream = Stream(mode=StreamMode.READWRITE, limit=limit, loop=loop, _asyncio_internal=True) @@ -102,7 +102,7 @@ async def start_server(client_connected_cb, host=None, port=None, *, loop = events.get_event_loop() def factory(): - stream = Stream(kind=StreamKind.READWRITE, + stream = Stream(mode=StreamMode.READWRITE, limit=limit, loop=loop, _asyncio_internal=True) @@ -122,7 +122,7 @@ async def open_unix_connection(path=None, *, """Similar to `open_connection` but works with UNIX Domain Sockets.""" if loop is None: loop = events.get_event_loop() - stream = Stream(kind=StreamKind.READWRITE, + stream = Stream(mode=StreamMode.READWRITE, limit=limit, loop=loop, _asyncio_internal=True) @@ -140,7 +140,7 @@ async def start_unix_server(client_connected_cb, path=None, *, loop = events.get_event_loop() def factory(): - stream = Stream(kind=StreamKind.READWRITE, + stream = Stream(mode=StreamMode.READWRITE, limit=limit, loop=loop, _asyncio_internal=True) @@ -362,7 +362,7 @@ class Stream: _source_traceback = None - def __init__(self, kind, *, + def __init__(self, mode, *, transport=None, protocol=None, loop=None, @@ -373,7 +373,7 @@ def __init__(self, kind, *, "by asyncio internals only, " "please avoid its creation from user code", DeprecationWarning) - self._kind = kind + self._mode = mode self._transport = transport self._protocol = protocol @@ -402,7 +402,7 @@ def __init__(self, kind, *, def __repr__(self): info = [self.__class__.__name__] - info.append(f'kind={self._kind}') + info.append(f'mode={self._mode}') if self._buffer: info.append(f'{len(self._buffer)} bytes') if self._eof: @@ -424,12 +424,12 @@ def transport(self): return self._transport def write(self, data): - self._kind.check_write() + self._mode.check_write() self._transport.write(data) return self._fast_drain() def writelines(self, data): - self._kind.check_write() + self._mode.check_write() self._transport.writelines(data) return self._fast_drain() @@ -454,11 +454,11 @@ def _fast_drain(self): return self._loop.create_task(self.drain()) def write_eof(self): - self._kind.check_write() + self._mode.check_write() return self._transport.write_eof() def can_write_eof(self): - if not self._kind.is_write(): + if not self._mode.is_write(): return False return self._transport.can_write_eof() @@ -483,7 +483,7 @@ async def drain(self): w.write(data) await w.drain() """ - self._kind.check_write() + self._mode.check_write() exc = self.exception() if exc is not None: raise exc @@ -526,17 +526,17 @@ def _maybe_resume_transport(self): self._transport.resume_reading() def feed_eof(self): - self._kind.check_read() + self._mode.check_read() self._eof = True self._wakeup_waiter() def at_eof(self): """Return True if the buffer is empty and 'feed_eof' was called.""" - self._kind.check_read() + self._mode.check_read() return self._eof and not self._buffer def feed_data(self, data): - self._kind.check_read() + self._mode.check_read() assert not self._eof, 'feed_data after feed_eof' if not data: @@ -602,7 +602,7 @@ async def readline(self): If stream was paused, this function will automatically resume it if needed. """ - self._kind.check_read() + self._mode.check_read() sep = b'\n' seplen = len(sep) try: @@ -638,7 +638,7 @@ async def readuntil(self, separator=b'\n'): LimitOverrunError exception will be raised, and the data will be left in the internal buffer, so it can be read again. """ - self._kind.check_read() + self._mode.check_read() seplen = len(separator) if seplen == 0: raise ValueError('Separator should be at least one-byte string') @@ -730,7 +730,7 @@ async def read(self, n=-1): If stream was paused, this function will automatically resume it if needed. """ - self._kind.check_read() + self._mode.check_read() if self._exception is not None: raise self._exception @@ -776,7 +776,7 @@ async def readexactly(self, n): If stream was paused, this function will automatically resume it if needed. """ - self._kind.check_read() + self._mode.check_read() if n < 0: raise ValueError('readexactly size can not be less than zero') @@ -804,7 +804,7 @@ async def readexactly(self, n): return data def __aiter__(self): - self._kind.check_read() + self._mode.check_read() return self async def __anext__(self): @@ -817,7 +817,10 @@ async def __anext__(self): class StreamWriter(Stream): def __init__(self, transport, protocol, reader, loop, *, _asyncio_internal=False): - super().__init__(kind=StreamKind.WRITE, + warnings.warn("StreamReader class is deprecated in favor of Stream", + DeprecationWarning, + stacklevel=2) + super().__init__(mode=StreamMode.WRITE, transport=transport, protocol=protocol, loop=loop, @@ -828,7 +831,10 @@ def __init__(self, transport, protocol, reader, loop, class StreamReader(Stream): def __init__(self, limit=_DEFAULT_LIMIT, loop=None, *, _asyncio_internal=False): - super().__init__(kind=StreamKind.READ, + warnings.warn("StreamWriter class is deprecated in favor of Stream", + DeprecationWarning, + stacklevel=2) + super().__init__(mode=StreamMode.READ, limit=limit, loop=loop, _asyncio_internal=_asyncio_internal, From a632339e80b02d719e7f7f06014430c84d07b2c7 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 15:12:53 +0300 Subject: [PATCH 10/82] Fix streams --- Lib/asyncio/streams.py | 1 + Lib/test/test_asyncio/test_streams.py | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 2eb12b4fb0a889..9ae5b352d5e981 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,5 +1,6 @@ __all__ = ( 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', + 'Stream', 'StreamMode', 'open_connection', 'start_server') import enum diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 8fdcb61f0410d8..37d7ce39c07929 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -887,31 +887,31 @@ async def client(host, port): def test___repr__(self): stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__nondefault_limit(self): stream = asyncio.StreamReader(loop=self.loop, limit=123, _asyncio_internal=True) - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__eof(self): stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) stream.feed_eof() - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__data(self): stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) stream.feed_data(b'data') - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__exception(self): stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) exc = RuntimeError() stream.set_exception(exc) - self.assertEqual("", + self.assertEqual("", repr(stream)) def test___repr__waiter(self): @@ -924,7 +924,7 @@ def test___repr__waiter(self): stream._waiter.set_result(None) self.loop.run_until_complete(stream._waiter) stream._waiter = None - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__transport(self): stream = asyncio.StreamReader(loop=self.loop, @@ -932,7 +932,7 @@ def test___repr__transport(self): stream._transport = mock.Mock() stream._transport.__repr__ = mock.Mock() stream._transport.__repr__.return_value = "" - self.assertEqual(">", + self.assertEqual(">", repr(stream)) def test_IncompleteReadError_pickleable(self): From 97641eb0982e9f47f425b72320490c025f1f0c76 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 15:20:43 +0300 Subject: [PATCH 11/82] Convert subprocess to use Stream instead of StreamReader/StreamWriter --- Lib/asyncio/streams.py | 2 ++ Lib/asyncio/subprocess.py | 29 +++++++++++++++++------------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9ae5b352d5e981..2b187c012ac75f 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -518,6 +518,8 @@ def _wakeup_waiter(self): waiter.set_result(None) def set_transport(self, transport): + if transport is self._transport: + return assert self._transport is None, 'Transport already set' self._transport = transport diff --git a/Lib/asyncio/subprocess.py b/Lib/asyncio/subprocess.py index d34b6118fdcf72..d7a37efa2e0acc 100644 --- a/Lib/asyncio/subprocess.py +++ b/Lib/asyncio/subprocess.py @@ -40,30 +40,35 @@ def __repr__(self): def connection_made(self, transport): self._transport = transport - stdout_transport = transport.get_pipe_transport(1) if stdout_transport is not None: - self.stdout = streams.StreamReader(limit=self._limit, - loop=self._loop, - _asyncio_internal=True) + self.stdout = streams.Stream(mode=streams.StreamMode.READ, + transport=stdout_transport, + protocol=self, + limit=self._limit, + loop=self._loop, + _asyncio_internal=True) self.stdout.set_transport(stdout_transport) self._pipe_fds.append(1) stderr_transport = transport.get_pipe_transport(2) if stderr_transport is not None: - self.stderr = streams.StreamReader(limit=self._limit, - loop=self._loop, - _asyncio_internal=True) + self.stderr = streams.Stream(mode=streams.StreamMode.READ, + transport=stderr_transport, + protocol=self, + limit=self._limit, + loop=self._loop, + _asyncio_internal=True) self.stderr.set_transport(stderr_transport) self._pipe_fds.append(2) stdin_transport = transport.get_pipe_transport(0) if stdin_transport is not None: - self.stdin = streams.StreamWriter(stdin_transport, - protocol=self, - reader=None, - loop=self._loop, - _asyncio_internal=True) + self.stdin = streams.Stream(mode=streams.StreamMode.WRITE, + transport=stdin_transport, + protocol=self, + loop=self._loop, + _asyncio_internal=True) def pipe_data_received(self, fd, data): if fd == 1: From 962495b478dab9dbfde66b20805af1d72d892b50 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 16:45:04 +0300 Subject: [PATCH 12/82] Reimplement deprecation strategy --- Lib/asyncio/__init__.py | 27 +++ Lib/asyncio/streams.py | 28 --- Lib/test/test_asyncio/test_pep492.py | 4 +- Lib/test/test_asyncio/test_streams.py | 267 +++++++++++++++----------- 4 files changed, 184 insertions(+), 142 deletions(-) diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index 28c2e2c429f34a..296e7d7209443c 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -3,6 +3,7 @@ # flake8: noqa import sys +import warnings # This relies on each of the submodules having an __all__ variable. from .base_events import * @@ -43,3 +44,29 @@ else: from .unix_events import * # pragma: no cover __all__ += unix_events.__all__ + + +__all__ += ('StreamReader', 'StreamWriter', 'StreamReaderProtocol') # deprecated + + +def __getattr__(name): + if name == 'StreamReader': + warnings.warn("StreamReader is deprecated, use asyncio.Stream instead", + DeprecationWarning, + stacklevel=2) + return Stream + if name == 'StreamWriter': + warnings.warn("StreamWriter is deprecated, use asyncio.Stream instead", + DeprecationWarning, + stacklevel=2) + return Stream + + if name == 'StreamReaderProtocol': + warnings.warn("StreamReaderProtocol is a private API, " + "don't use the class in user code", + DeprecationWarning, + stacklevel=2) + from .streams import StreamReaderProtocol + return StreamReaderProtocol + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 2b187c012ac75f..c5d44591379d5d 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,5 +1,4 @@ __all__ = ( - 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', 'Stream', 'StreamMode', 'open_connection', 'start_server') @@ -815,30 +814,3 @@ async def __anext__(self): if val == b'': raise StopAsyncIteration return val - - -class StreamWriter(Stream): - def __init__(self, transport, protocol, reader, loop, - *, _asyncio_internal=False): - warnings.warn("StreamReader class is deprecated in favor of Stream", - DeprecationWarning, - stacklevel=2) - super().__init__(mode=StreamMode.WRITE, - transport=transport, - protocol=protocol, - loop=loop, - _asyncio_internal=_asyncio_internal - ) - - -class StreamReader(Stream): - def __init__(self, limit=_DEFAULT_LIMIT, loop=None, - *, _asyncio_internal=False): - warnings.warn("StreamWriter class is deprecated in favor of Stream", - DeprecationWarning, - stacklevel=2) - super().__init__(mode=StreamMode.READ, - limit=limit, - loop=loop, - _asyncio_internal=_asyncio_internal, - ) diff --git a/Lib/test/test_asyncio/test_pep492.py b/Lib/test/test_asyncio/test_pep492.py index 558e268415cd00..53d159adade79d 100644 --- a/Lib/test/test_asyncio/test_pep492.py +++ b/Lib/test/test_asyncio/test_pep492.py @@ -94,7 +94,9 @@ class StreamReaderTests(BaseTest): def test_readline(self): DATA = b'line1\nline2\nline3' - stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(DATA) stream.feed_eof() diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 37d7ce39c07929..c36f229f5cd379 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -16,6 +16,7 @@ ssl = None import asyncio +from asyncio.streams import StreamReaderProtocol from test.test_asyncio import utils as test_utils @@ -42,7 +43,7 @@ def tearDown(self): @mock.patch('asyncio.streams.events') def test_ctor_global_loop(self, m_events): - stream = asyncio.StreamReader(_asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, _asyncio_internal=True) self.assertIs(stream._loop, m_events.get_event_loop.return_value) def _basetest_open_connection(self, open_connection_fut): @@ -158,23 +159,26 @@ def test_open_unix_connection_error(self): self._basetest_open_connection_error(conn_fut) def test_feed_empty_data(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'') self.assertEqual(b'', stream._buffer) def test_feed_nonempty_data(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) self.assertEqual(self.DATA, stream._buffer) def test_read_zero(self): # Read zero bytes. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.read(0)) @@ -183,8 +187,9 @@ def test_read_zero(self): def test_read(self): # Read bytes. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) read_task = asyncio.Task(stream.read(30), loop=self.loop) def cb(): @@ -197,8 +202,9 @@ def cb(): def test_read_line_breaks(self): # Read bytes without line breaks. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line1') stream.feed_data(b'line2') @@ -209,8 +215,9 @@ def test_read_line_breaks(self): def test_read_eof(self): # Read bytes, stop at eof. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) read_task = asyncio.Task(stream.read(1024), loop=self.loop) def cb(): @@ -223,8 +230,9 @@ def cb(): def test_read_until_eof(self): # Read all bytes until eof. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) read_task = asyncio.Task(stream.read(-1), loop=self.loop) def cb(): @@ -239,8 +247,9 @@ def cb(): self.assertEqual(b'', stream._buffer) def test_read_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.read(2)) @@ -252,16 +261,19 @@ def test_read_exception(self): def test_invalid_limit(self): with self.assertRaisesRegex(ValueError, 'imit'): - asyncio.StreamReader(limit=0, loop=self.loop, - _asyncio_internal=True) + asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=0, loop=self.loop, + _asyncio_internal=True) with self.assertRaisesRegex(ValueError, 'imit'): - asyncio.StreamReader(limit=-1, loop=self.loop, - _asyncio_internal=True) + asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=-1, loop=self.loop, + _asyncio_internal=True) def test_read_limit(self): - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'chunk') data = self.loop.run_until_complete(stream.read(5)) self.assertEqual(b'chunk', data) @@ -270,8 +282,9 @@ def test_read_limit(self): def test_readline(self): # Read one line. 'readline' will need to wait for the data # to come from 'cb' - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'chunk1 ') read_task = asyncio.Task(stream.readline(), loop=self.loop) @@ -286,11 +299,12 @@ def cb(): self.assertEqual(b' chunk4', stream._buffer) def test_readline_limit_with_existing_data(self): - # Read one line. The data is in StreamReader's buffer + # Read one line. The data is in Stream's buffer # before the event loop is run. - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') @@ -299,8 +313,9 @@ def test_readline_limit_with_existing_data(self): # The buffer should contain the remaining data after exception self.assertEqual(b'line2\n', stream._buffer) - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'li') stream.feed_data(b'ne1') stream.feed_data(b'li') @@ -315,8 +330,9 @@ def test_readline_limit_with_existing_data(self): self.assertEqual(b'', stream._buffer) def test_at_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) self.assertFalse(stream.at_eof()) stream.feed_data(b'some data\n') @@ -334,8 +350,9 @@ def test_readline_limit(self): # Read one line. StreamReaders are fed with data after # their 'readline' methods are called. - stream = asyncio.StreamReader(limit=7, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=7, loop=self.loop, + _asyncio_internal=True) def cb(): stream.feed_data(b'chunk1') stream.feed_data(b'chunk2') @@ -349,8 +366,9 @@ def cb(): # a ValueError it should be empty. self.assertEqual(b'', stream._buffer) - stream = asyncio.StreamReader(limit=7, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=7, loop=self.loop, + _asyncio_internal=True) def cb(): stream.feed_data(b'chunk1') stream.feed_data(b'chunk2\n') @@ -363,8 +381,9 @@ def cb(): self.assertEqual(b'chunk3\n', stream._buffer) # check strictness of the limit - stream = asyncio.StreamReader(limit=7, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=7, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'1234567\n') line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'1234567\n', line) @@ -383,8 +402,9 @@ def cb(): def test_readline_nolimit_nowait(self): # All needed data for the first 'readline' call will be # in the buffer. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) @@ -394,8 +414,9 @@ def test_readline_nolimit_nowait(self): self.assertEqual(b'line2\nline3\n', stream._buffer) def test_readline_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'some data') stream.feed_eof() @@ -403,16 +424,18 @@ def test_readline_eof(self): self.assertEqual(b'some data', line) def test_readline_empty_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_eof() line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'', line) def test_readline_read_byte_count(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) self.loop.run_until_complete(stream.readline()) @@ -423,8 +446,9 @@ def test_readline_read_byte_count(self): self.assertEqual(b'ine3\n', stream._buffer) def test_readline_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readline()) @@ -436,14 +460,16 @@ def test_readline_exception(self): self.assertEqual(b'', stream._buffer) def test_readuntil_separator(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) with self.assertRaisesRegex(ValueError, 'Separator should be'): self.loop.run_until_complete(stream.readuntil(separator=b'')) def test_readuntil_multi_chunks(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'lineAAA') data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA')) @@ -461,8 +487,9 @@ def test_readuntil_multi_chunks(self): self.assertEqual(b'xxx', stream._buffer) def test_readuntil_multi_chunks_1(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'QWEaa') stream.feed_data(b'XYaa') @@ -497,8 +524,9 @@ def test_readuntil_multi_chunks_1(self): self.assertEqual(b'', stream._buffer) def test_readuntil_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'some dataAA') stream.feed_eof() @@ -509,8 +537,9 @@ def test_readuntil_eof(self): self.assertEqual(b'', stream._buffer) def test_readuntil_limit_found_sep(self): - stream = asyncio.StreamReader(loop=self.loop, limit=3, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=3, + _asyncio_internal=True) stream.feed_data(b'some dataAA') with self.assertRaisesRegex(asyncio.LimitOverrunError, @@ -528,8 +557,9 @@ def test_readuntil_limit_found_sep(self): def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.readexactly(0)) @@ -542,8 +572,9 @@ def test_readexactly_zero_or_less(self): def test_readexactly(self): # Read exact number of bytes. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) n = 2 * len(self.DATA) read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) @@ -559,8 +590,9 @@ def cb(): self.assertEqual(self.DATA, stream._buffer) def test_readexactly_limit(self): - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'chunk') data = self.loop.run_until_complete(stream.readexactly(5)) self.assertEqual(b'chunk', data) @@ -568,8 +600,9 @@ def test_readexactly_limit(self): def test_readexactly_eof(self): # Read exact number of bytes (eof). - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) n = 2 * len(self.DATA) read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) @@ -587,8 +620,9 @@ def cb(): self.assertEqual(b'', stream._buffer) def test_readexactly_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readexactly(2)) @@ -599,8 +633,9 @@ def test_readexactly_exception(self): ValueError, self.loop.run_until_complete, stream.readexactly(2)) def test_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) self.assertIsNone(stream.exception()) exc = ValueError() @@ -608,8 +643,9 @@ def test_exception(self): self.assertIs(stream.exception(), exc) def test_exception_waiter(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) @asyncio.coroutine def set_err(): @@ -623,8 +659,9 @@ def set_err(): self.assertRaises(ValueError, t1.result) def test_exception_cancel(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) t = asyncio.Task(stream.readline(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -787,7 +824,7 @@ async def client(path): def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example # subprocess_attach_read_pipe.py, but we configure the - # StreamReader's limit so that twice it is less than the size + # Stream's limit so that twice it is less than the size # of the data writter. Also we must explicitly attach a child # watcher to the event loop. @@ -801,10 +838,11 @@ def test_read_all_from_pipe_reader(self): args = [sys.executable, '-c', code, str(wfd)] pipe = open(rfd, 'rb', 0) - reader = asyncio.StreamReader(loop=self.loop, limit=1, - _asyncio_internal=True) - protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop, - _asyncio_internal=True) + reader = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=1, + _asyncio_internal=True) + protocol = StreamReaderProtocol(reader, loop=self.loop, + _asyncio_internal=True) transport, _ = self.loop.run_until_complete( self.loop.connect_read_pipe(lambda: protocol, pipe)) @@ -830,7 +868,8 @@ def test_streamreader_constructor(self): # asyncio issue #184: Ensure that StreamReaderProtocol constructor # retrieves the current loop if the loop parameter is not set - reader = asyncio.StreamReader(_asyncio_internal=True) + reader = asyncio.Stream(mode=asyncio.StreamMode.READ, + _asyncio_internal=True) self.assertIs(reader._loop, self.loop) def test_streamreaderprotocol_constructor(self): @@ -840,7 +879,7 @@ def test_streamreaderprotocol_constructor(self): # asyncio issue #184: Ensure that StreamReaderProtocol constructor # retrieves the current loop if the loop parameter is not set reader = mock.Mock() - protocol = asyncio.StreamReaderProtocol(reader, _asyncio_internal=True) + protocol = StreamReaderProtocol(reader, _asyncio_internal=True) self.assertIs(protocol._loop, self.loop) def test_drain_raises(self): @@ -885,54 +924,61 @@ async def client(host, port): thread.join() def test___repr__(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - self.assertEqual("", repr(stream)) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) + self.assertEqual("", repr(stream)) def test___repr__nondefault_limit(self): - stream = asyncio.StreamReader(loop=self.loop, limit=123, + stream = asyncio.StreamReader(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=123, _asyncio_internal=True) - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_eof() - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__data(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'data') - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) exc = RuntimeError() stream.set_exception(exc) - self.assertEqual("", + self.assertEqual("", repr(stream)) def test___repr__waiter(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream._waiter = asyncio.Future(loop=self.loop) self.assertRegex( repr(stream), - r">") + r">") stream._waiter.set_result(None) self.loop.run_until_complete(stream._waiter) stream._waiter = None self.assertEqual("", repr(stream)) def test___repr__transport(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream._transport = mock.Mock() stream._transport.__repr__ = mock.Mock() stream._transport.__repr__.return_value = "" - self.assertEqual(">", + self.assertEqual(">", repr(stream)) def test_IncompleteReadError_pickleable(self): @@ -1016,10 +1062,11 @@ def test_del_stream_before_connection_made(self): self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) with test_utils.run_test_server() as httpd: - rd = asyncio.StreamReader(loop=self.loop, + rd = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) + pr = StreamReaderProtocol(rd, loop=self.loop, _asyncio_internal=True) - pr = asyncio.StreamReaderProtocol(rd, loop=self.loop, - _asyncio_internal=True) del rd gc.collect() tr, _ = self.loop.run_until_complete( @@ -1096,21 +1143,15 @@ def test_eof_feed_when_closing_writer(self): def test_stream_reader_create_warning(self): with self.assertWarns(DeprecationWarning): - asyncio.StreamReader(loop=self.loop) + asyncio.StreamReader def test_stream_reader_protocol_create_warning(self): - reader = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) with self.assertWarns(DeprecationWarning): - asyncio.StreamReaderProtocol(reader, loop=self.loop) + asyncio.StreamReaderProtocol def test_stream_writer_create_warning(self): - reader = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - proto = asyncio.StreamReaderProtocol(reader, loop=self.loop, - _asyncio_internal=True) with self.assertWarns(DeprecationWarning): - asyncio.StreamWriter('transport', proto, reader, self.loop) + asyncio.StreamWriter From 0ec17eda79246227368b568282c1c3e38cf69c7a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 17:56:29 +0300 Subject: [PATCH 13/82] Disallow asyncio.StreamReaderProtocol usage outside of asyncio package --- Lib/asyncio/__init__.py | 8 ++------ Lib/test/test_asyncio/test_streams.py | 12 ++++++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index 296e7d7209443c..513e4fcd670232 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -62,11 +62,7 @@ def __getattr__(name): return Stream if name == 'StreamReaderProtocol': - warnings.warn("StreamReaderProtocol is a private API, " - "don't use the class in user code", - DeprecationWarning, - stacklevel=2) - from .streams import StreamReaderProtocol - return StreamReaderProtocol + raise AttributeError("StreamReaderProtocol is a private API, " + "don't use the class in user code") raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index c36f229f5cd379..f8cfa7208a47d8 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -347,7 +347,7 @@ def test_at_eof(self): self.assertTrue(stream.at_eof()) def test_readline_limit(self): - # Read one line. StreamReaders are fed with data after + # Read one line. Streams are fed with data after # their 'readline' methods are called. stream = asyncio.Stream(mode=asyncio.StreamMode.READ, @@ -930,9 +930,9 @@ def test___repr__(self): self.assertEqual("", repr(stream)) def test___repr__nondefault_limit(self): - stream = asyncio.StreamReader(mode=asyncio.StreamMode.READ, - loop=self.loop, limit=123, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=123, + _asyncio_internal=True) self.assertEqual("", repr(stream)) def test___repr__eof(self): @@ -969,7 +969,7 @@ def test___repr__waiter(self): stream._waiter.set_result(None) self.loop.run_until_complete(stream._waiter) stream._waiter = None - self.assertEqual("", repr(stream)) + self.assertEqual("", repr(stream)) def test___repr__transport(self): stream = asyncio.Stream(mode=asyncio.StreamMode.READ, @@ -1146,7 +1146,7 @@ def test_stream_reader_create_warning(self): asyncio.StreamReader def test_stream_reader_protocol_create_warning(self): - with self.assertWarns(DeprecationWarning): + with self.assertRaises(AttributeError): asyncio.StreamReaderProtocol def test_stream_writer_create_warning(self): From 77bf0e45903de8a6ce0a8e716f5aa41fa1d12602 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 18:36:10 +0300 Subject: [PATCH 14/82] Fix test___all__ --- Lib/asyncio/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index 513e4fcd670232..b7a11aeac697f7 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -46,7 +46,7 @@ __all__ += unix_events.__all__ -__all__ += ('StreamReader', 'StreamWriter', 'StreamReaderProtocol') # deprecated +__all__ += ('StreamReader', 'StreamWriter') # deprecated def __getattr__(name): From f46854f2e9628b96043abc5d8b0469dfebc27c21 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 18:36:52 +0300 Subject: [PATCH 15/82] Don't expose asyncio.StreamReaderProtocol --- Lib/asyncio/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index b7a11aeac697f7..8b986e86d17590 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -61,8 +61,4 @@ def __getattr__(name): stacklevel=2) return Stream - if name == 'StreamReaderProtocol': - raise AttributeError("StreamReaderProtocol is a private API, " - "don't use the class in user code") - raise AttributeError(f"module {__name__} has no attribute {name}") From 13f1430f65cebfd75df5fa27c95920099ee77538 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 19:26:15 +0300 Subject: [PATCH 16/82] Ignore warning registry in test___all__ --- Lib/test/test___all__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test___all__.py b/Lib/test/test___all__.py index f6e82eb64ab025..116abca6c8a629 100644 --- a/Lib/test/test___all__.py +++ b/Lib/test/test___all__.py @@ -40,6 +40,8 @@ def check_all(self, modname): del names["__builtins__"] if '__annotations__' in names: del names['__annotations__'] + if "__warningregistry__" in names: + del names["__warningregistry__"] keys = set(names) all_list = sys.modules[modname].__all__ all_set = set(all_list) From a1bf0b7f479aeb06bc88416822f61b9eef71ca66 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 20:26:04 +0300 Subject: [PATCH 17/82] Support close on stdout/stderr --- Lib/asyncio/streams.py | 3 +-- Lib/asyncio/subprocess.py | 6 ++++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index c5d44591379d5d..eff41837f21a92 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -249,8 +249,7 @@ def __init__(self, stream, client_connected_cb=None, loop=None, self._stream_wr = None else: self._strong_stream = None - self._stream_wr = weakref.ref(stream, - self._on_gc) + self._stream_wr = weakref.ref(stream, self._on_gc) self._source_traceback = stream._source_traceback self._reject_connection = False self._transport = None diff --git a/Lib/asyncio/subprocess.py b/Lib/asyncio/subprocess.py index d7a37efa2e0acc..e6bec71d6c7dac 100644 --- a/Lib/asyncio/subprocess.py +++ b/Lib/asyncio/subprocess.py @@ -27,6 +27,8 @@ def __init__(self, limit, loop, *, _asyncio_internal=False): self._process_exited = False self._pipe_fds = [] self._stdin_closed = self._loop.create_future() + self._stdout_closed = self._loop.create_future() + self._stderr_closed = self._loop.create_future() def __repr__(self): info = [self.__class__.__name__] @@ -119,6 +121,10 @@ def _maybe_close_transport(self): def _get_close_waiter(self, stream): if stream is self.stdin: return self._stdin_closed + elif stream is self.stdout: + return self._stdout_closed + elif stream is self.stderr: + return self._stderr_closed class Process: From 062b570d5b232720cfec1eddaed23612bad0c9b3 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 21:30:41 +0300 Subject: [PATCH 18/82] Fix Windows --- Lib/test/test_asyncio/test_windows_events.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py index 05f85159be0cd5..24730fc5e8f24f 100644 --- a/Lib/test/test_asyncio/test_windows_events.py +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -17,6 +17,7 @@ import asyncio from asyncio import windows_events +from asyncio.streams import StreamReaderProtocol from test.test_asyncio import utils as test_utils from test.support.script_helper import spawn_python @@ -100,8 +101,9 @@ async def _test_pipe(self): clients = [] for i in range(5): - stream_reader = asyncio.StreamReader(loop=self.loop) - protocol = asyncio.StreamReaderProtocol(stream_reader, + stream = asyncio.StreamReader(mode=asyncio.StreamMode.READ, + loop=self.loop) + protocol = asyncio.StreamReaderProtocol(stream, loop=self.loop) trans, proto = await self.loop.create_pipe_connection( lambda: protocol, ADDRESS) From ec24c7024cb4387c071704787f6e8898c28ddfc9 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 May 2019 22:06:02 +0300 Subject: [PATCH 19/82] make patchcheck --- Lib/asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index eff41837f21a92..9238028dc7ce5c 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -411,7 +411,7 @@ def __repr__(self): if self._waiter: info.append(f'waiter={self._waiter!r}') if self._exception: - info.append(f'exception={self._exception!r}') + info.append(f'exception={self._exception!r}') if self._transport: info.append(f'transport={self._transport!r}') if self._paused: From a45bdca30a2c376cbe1e6c239ae20c6dfb43d112 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 14 May 2019 00:43:02 +0300 Subject: [PATCH 20/82] Fix import names --- Lib/test/test_asyncio/test_windows_events.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py index 24730fc5e8f24f..d014807633599c 100644 --- a/Lib/test/test_asyncio/test_windows_events.py +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -101,10 +101,10 @@ async def _test_pipe(self): clients = [] for i in range(5): - stream = asyncio.StreamReader(mode=asyncio.StreamMode.READ, - loop=self.loop) - protocol = asyncio.StreamReaderProtocol(stream, - loop=self.loop) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop) + protocol = StreamReaderProtocol(stream, + loop=self.loop) trans, proto = await self.loop.create_pipe_connection( lambda: protocol, ADDRESS) self.assertIsInstance(trans, asyncio.Transport) From 5cc3107463ff0f20c355b09a62d2fafcdba80dcf Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 14 May 2019 10:28:00 +0300 Subject: [PATCH 21/82] Fix typo --- Lib/test/test_asyncio/test_windows_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py index d014807633599c..c776fcf1e117d7 100644 --- a/Lib/test/test_asyncio/test_windows_events.py +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -109,7 +109,7 @@ async def _test_pipe(self): lambda: protocol, ADDRESS) self.assertIsInstance(trans, asyncio.Transport) self.assertEqual(protocol, proto) - clients.append((stream_reader, trans)) + clients.append((stream, trans)) for i, (r, w) in enumerate(clients): w.write('lower-{}\n'.format(i).encode()) From 1d2cf877563a59c9f4afe40d225d8d184a13f202 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 14 May 2019 13:41:51 +0300 Subject: [PATCH 22/82] Add NEWS --- .../next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst diff --git a/Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst b/Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst new file mode 100644 index 00000000000000..1009eb6e841669 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst @@ -0,0 +1,2 @@ +Merge asyncio.StreamReader and asyncio.StreamWriter into asyncio.Stream +class with readonly, writeonly and readwrite modes. From d2f91c1a88212d35913c153ec23ab8c91b5f7a76 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 14 May 2019 13:42:16 +0300 Subject: [PATCH 23/82] Split StreamReaderProtocol into _StreamProtocol and _ServerStreamProtocol --- Lib/asyncio/streams.py | 148 +++++++++++-------- Lib/test/test_asyncio/test_streams.py | 34 ++--- Lib/test/test_asyncio/test_windows_events.py | 6 +- 3 files changed, 102 insertions(+), 86 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9238028dc7ce5c..6f3cd98a999634 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -69,8 +69,8 @@ async def open_connection(host=None, port=None, *, loop=loop, _asyncio_internal=True) transport, _ = await loop.create_connection( - lambda: StreamReaderProtocol(stream, loop=loop, - _asyncio_internal=True), + lambda: _StreamProtocol(stream, loop=loop, + _asyncio_internal=True), host, port, **kwds) return stream, stream @@ -106,9 +106,9 @@ def factory(): limit=limit, loop=loop, _asyncio_internal=True) - protocol = StreamReaderProtocol(stream, client_connected_cb, - loop=loop, - _asyncio_internal=True) + protocol = _ServerStreamProtocol(stream, client_connected_cb, + loop=loop, + _asyncio_internal=True) return protocol return await loop.create_server(factory, host, port, **kwds) @@ -127,9 +127,9 @@ async def open_unix_connection(path=None, *, loop=loop, _asyncio_internal=True) transport, _ = await loop.create_unix_connection( - lambda: StreamReaderProtocol(stream, - loop=loop, - _asyncio_internal=True), + lambda: _StreamProtocol(stream, + loop=loop, + _asyncio_internal=True), path, **kwds) return stream, stream @@ -144,10 +144,10 @@ def factory(): limit=limit, loop=loop, _asyncio_internal=True) - protocol = StreamReaderProtocol(stream, - client_connected_cb, - loop=loop, - _asyncio_internal=True) + protocol = _ServerStreamProtocol(stream, + client_connected_cb, + loop=loop, + _asyncio_internal=True) return protocol return await loop.create_unix_server(factory, path, **kwds) @@ -230,7 +230,7 @@ def _get_close_waiter(self, stream): raise NotImplementedError -class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): +class _BaseStreamProtocol(FlowControlMixin, protocols.Protocol): """Helper class to adapt between Protocol and StreamReader. (This is a helper class instead of making StreamReader itself a @@ -240,59 +240,16 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): """ _source_traceback = None + _stream = None # initialized in derived classes - def __init__(self, stream, client_connected_cb=None, loop=None, + def __init__(self, loop=None, *, _asyncio_internal=False): super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) - if client_connected_cb: - self._strong_stream = stream - self._stream_wr = None - else: - self._strong_stream = None - self._stream_wr = weakref.ref(stream, self._on_gc) - self._source_traceback = stream._source_traceback - self._reject_connection = False self._transport = None - self._client_connected_cb = client_connected_cb self._over_ssl = False self._closed = self._loop.create_future() - def _on_gc(self, wr): - transport = self._transport - if transport is not None: - # connection_made was called - context = { - 'message': ('An open stream object is being garbage ' - 'collected; call "stream.close()" explicitly.') - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - transport.abort() - else: - self._reject_connection = True - self._stream_wr = None - - @property - def _stream(self): - if self._strong_stream is not None: - return self._strong_stream - if self._stream_wr is None: - return None - return self._stream_wr() - def connection_made(self, transport): - if self._reject_connection: - context = { - 'message': ('An open stream was garbage collected prior to ' - 'establishing network connection; ' - 'call "stream.close()" explicitly.') - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - transport.abort() - return self._transport = transport stream = self._stream if stream is None: @@ -300,11 +257,6 @@ def connection_made(self, transport): stream.set_transport(transport) stream._protocol = self self._over_ssl = transport.get_extra_info('sslcontext') is not None - if self._client_connected_cb is not None: - res = self._client_connected_cb(self._stream, - self._stream) - if coroutines.iscoroutine(res): - self._loop.create_task(res) def connection_lost(self, exc): stream = self._stream @@ -319,7 +271,6 @@ def connection_lost(self, exc): else: self._closed.set_exception(exc) super().connection_lost(exc) - self._stream_wr = None self._transport = None def data_received(self, data): @@ -349,6 +300,75 @@ def __del__(self): closed.exception() +class _StreamProtocol(_BaseStreamProtocol): + def __init__(self, stream, loop=None, + *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + self._source_traceback = stream._source_traceback + self._stream_wr = weakref.ref(stream, self._on_gc) + self._reject_connection = False + + def _on_gc(self, wr): + transport = self._transport + if transport is not None: + # connection_made was called + context = { + 'message': ('An open stream object is being garbage ' + 'collected; call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + else: + self._reject_connection = True + self._stream_wr = None + + @property + def _stream(self): + if self._stream_wr is None: + return None + return self._stream_wr() + + def connection_made(self, transport): + if self._reject_connection: + context = { + 'message': ('An open stream was garbage collected prior to ' + 'establishing network connection; ' + 'call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + return + super().connection_made(transport) + + def connection_lost(self, exc): + super().connection_lost(exc) + self._stream_wr = None + + +class _ServerStreamProtocol(_BaseStreamProtocol): + def __init__(self, stream, client_connected_cb, loop=None, + *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + self._source_traceback = stream._source_traceback + self._stream = stream + self._client_connected_cb = client_connected_cb + + def connection_made(self, transport): + super().connection_made(transport) + res = self._client_connected_cb(self._stream, + self._stream) + if coroutines.iscoroutine(res): + self._loop.create_task(res) + + def connection_lost(self, exc): + super().connection_lost(exc) + self._stream = None + + class Stream: """Wraps a Transport. diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index f8cfa7208a47d8..279d0d4adfe41b 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -16,7 +16,7 @@ ssl = None import asyncio -from asyncio.streams import StreamReaderProtocol +from asyncio.streams import _StreamProtocol from test.test_asyncio import utils as test_utils @@ -838,11 +838,11 @@ def test_read_all_from_pipe_reader(self): args = [sys.executable, '-c', code, str(wfd)] pipe = open(rfd, 'rb', 0) - reader = asyncio.Stream(mode=asyncio.StreamMode.READ, + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, loop=self.loop, limit=1, _asyncio_internal=True) - protocol = StreamReaderProtocol(reader, loop=self.loop, - _asyncio_internal=True) + protocol = _StreamProtocol(stream, loop=self.loop, + _asyncio_internal=True) transport, _ = self.loop.run_until_complete( self.loop.connect_read_pipe(lambda: protocol, pipe)) @@ -859,14 +859,14 @@ def test_read_all_from_pipe_reader(self): asyncio.set_child_watcher(None) os.close(wfd) - data = self.loop.run_until_complete(reader.read(-1)) + data = self.loop.run_until_complete(stream.read(-1)) self.assertEqual(data, b'data') def test_streamreader_constructor(self): self.addCleanup(asyncio.set_event_loop, None) asyncio.set_event_loop(self.loop) - # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # asyncio issue #184: Ensure that _StreamProtocol constructor # retrieves the current loop if the loop parameter is not set reader = asyncio.Stream(mode=asyncio.StreamMode.READ, _asyncio_internal=True) @@ -876,10 +876,10 @@ def test_streamreaderprotocol_constructor(self): self.addCleanup(asyncio.set_event_loop, None) asyncio.set_event_loop(self.loop) - # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # asyncio issue #184: Ensure that _StreamProtocol constructor # retrieves the current loop if the loop parameter is not set - reader = mock.Mock() - protocol = StreamReaderProtocol(reader, _asyncio_internal=True) + stream = mock.Mock() + protocol = _StreamProtocol(stream, _asyncio_internal=True) self.assertIs(protocol._loop, self.loop) def test_drain_raises(self): @@ -1062,12 +1062,12 @@ def test_del_stream_before_connection_made(self): self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) with test_utils.run_test_server() as httpd: - rd = asyncio.Stream(mode=asyncio.StreamMode.READ, - loop=self.loop, - _asyncio_internal=True) - pr = StreamReaderProtocol(rd, loop=self.loop, - _asyncio_internal=True) - del rd + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) + pr = _StreamProtocol(stream, loop=self.loop, + _asyncio_internal=True) + del stream gc.collect() tr, _ = self.loop.run_until_complete( self.loop.create_connection( @@ -1145,10 +1145,6 @@ def test_stream_reader_create_warning(self): with self.assertWarns(DeprecationWarning): asyncio.StreamReader - def test_stream_reader_protocol_create_warning(self): - with self.assertRaises(AttributeError): - asyncio.StreamReaderProtocol - def test_stream_writer_create_warning(self): with self.assertWarns(DeprecationWarning): asyncio.StreamWriter diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py index c776fcf1e117d7..0236c8ab882ab2 100644 --- a/Lib/test/test_asyncio/test_windows_events.py +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -17,7 +17,7 @@ import asyncio from asyncio import windows_events -from asyncio.streams import StreamReaderProtocol +from asyncio.streams import _StreamProtocol from test.test_asyncio import utils as test_utils from test.support.script_helper import spawn_python @@ -103,8 +103,8 @@ async def _test_pipe(self): for i in range(5): stream = asyncio.Stream(mode=asyncio.StreamMode.READ, loop=self.loop) - protocol = StreamReaderProtocol(stream, - loop=self.loop) + protocol = _StreamProtocol(stream, + loop=self.loop) trans, proto = await self.loop.create_pipe_connection( lambda: protocol, ADDRESS) self.assertIsInstance(trans, asyncio.Transport) From 7bc248e8c1cddac8adb7727c310149f54b09a7eb Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 14 May 2019 13:53:22 +0300 Subject: [PATCH 24/82] Replace Enum with Flag --- Lib/asyncio/streams.py | 52 ++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 6f3cd98a999634..b48ccbab409bb3 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -23,25 +23,19 @@ _DEFAULT_LIMIT = 2 ** 16 # 64 KiB -class StreamMode(enum.Enum): - READ = "read" - WRITE = "write" - READWRITE = "readwrite" +class StreamMode(enum.Flag): + READ = enum.auto() + WRITE = enum.auto() + READWRITE = READ | WRITE - def is_read(self): - return self in (self.READ, self.READWRITE) - - def is_write(self): - return self in (self.WRITE, self.READWRITE) + def _check_read(self): + if not self & self.READ: + raise RuntimeError("The stream is write-only") - def check_read(self): - if not self.is_read(): + def _check_write(self): + if not self & self.WRITE: raise RuntimeError("The stream is read-only") - def check_write(self): - if not self.is_write(): - raise RuntimeError("The stream is write-only") - async def open_connection(host=None, port=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): @@ -438,17 +432,21 @@ def __repr__(self): info.append('paused') return '<{}>'.format(' '.join(info)) + @property + def mode(self): + return self._mode + @property def transport(self): return self._transport def write(self, data): - self._mode.check_write() + self._mode._check_write() self._transport.write(data) return self._fast_drain() def writelines(self, data): - self._mode.check_write() + self._mode._check_write() self._transport.writelines(data) return self._fast_drain() @@ -473,7 +471,7 @@ def _fast_drain(self): return self._loop.create_task(self.drain()) def write_eof(self): - self._mode.check_write() + self._mode._check_write() return self._transport.write_eof() def can_write_eof(self): @@ -502,7 +500,7 @@ async def drain(self): w.write(data) await w.drain() """ - self._mode.check_write() + self._mode._check_write() exc = self.exception() if exc is not None: raise exc @@ -547,17 +545,17 @@ def _maybe_resume_transport(self): self._transport.resume_reading() def feed_eof(self): - self._mode.check_read() + self._mode._check_read() self._eof = True self._wakeup_waiter() def at_eof(self): """Return True if the buffer is empty and 'feed_eof' was called.""" - self._mode.check_read() + self._mode._check_read() return self._eof and not self._buffer def feed_data(self, data): - self._mode.check_read() + self._mode._check_read() assert not self._eof, 'feed_data after feed_eof' if not data: @@ -623,7 +621,7 @@ async def readline(self): If stream was paused, this function will automatically resume it if needed. """ - self._mode.check_read() + self._mode._check_read() sep = b'\n' seplen = len(sep) try: @@ -659,7 +657,7 @@ async def readuntil(self, separator=b'\n'): LimitOverrunError exception will be raised, and the data will be left in the internal buffer, so it can be read again. """ - self._mode.check_read() + self._mode._check_read() seplen = len(separator) if seplen == 0: raise ValueError('Separator should be at least one-byte string') @@ -751,7 +749,7 @@ async def read(self, n=-1): If stream was paused, this function will automatically resume it if needed. """ - self._mode.check_read() + self._mode._check_read() if self._exception is not None: raise self._exception @@ -797,7 +795,7 @@ async def readexactly(self, n): If stream was paused, this function will automatically resume it if needed. """ - self._mode.check_read() + self._mode._check_read() if n < 0: raise ValueError('readexactly size can not be less than zero') @@ -825,7 +823,7 @@ async def readexactly(self, n): return data def __aiter__(self): - self._mode.check_read() + self._mode._check_read() return self async def __anext__(self): From 627028904a2919e680f3980a7c7584ebf0ba8eeb Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 14 May 2019 14:36:43 +0300 Subject: [PATCH 25/82] Add tests for stream mode --- Lib/test/test_asyncio/test_streams.py | 55 ++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 279d0d4adfe41b..cb84e9f01fff3f 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -24,6 +24,24 @@ def tearDownModule(): asyncio.set_event_loop_policy(None) +class StreamModeTests(unittest.TestCase): + def test__check_read_ok(self): + self.assertIsNone(asyncio.StreamMode.READ._check_read()) + self.assertIsNone(asyncio.StreamMode.READWRITE._check_read()) + + def test__check_read_fail(self): + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + asyncio.StreamMode.WRITE._check_read() + + def test__check_write_ok(self): + self.assertIsNone(asyncio.StreamMode.WRITE._check_write()) + self.assertIsNone(asyncio.StreamMode.READWRITE._check_write()) + + def test__check_write_fail(self): + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + asyncio.StreamMode.READ._check_write() + + class StreamTests(test_utils.TestCase): DATA = b'line1\nline2\nline3\n' @@ -1149,7 +1167,42 @@ def test_stream_writer_create_warning(self): with self.assertWarns(DeprecationWarning): asyncio.StreamWriter - + def test_stream_reader_forbidden_ops(self): + async def inner(): + stream = asyncio.Stream(mode=asyncio.StreamMode.READ) + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + await stream.write(b'data') + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + await stream.writelines([b'data', b'other']) + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + stream.write_eof() + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + await stream.drain() + + self.loop.run_until_complete(inner()) + + def test_stream_writer_forbidden_ops(self): + async def inner(): + stream = asyncio.Stream(mode=asyncio.StreamMode.WRITE) + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + stream.feed_eof() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + stream.at_eof() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + stream.feed_data(b'data') + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.readline() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.readuntil() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.read() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.readexactly(10) + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + async for chunk in stream: + pass + + self.loop.run_until_complete(inner()) if __name__ == '__main__': unittest.main() From f04352baadc5a7fd29a5c2d9684ef60f9b25cdff Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 14 May 2019 14:46:37 +0300 Subject: [PATCH 26/82] Create a server stream in protocol.connection_made() --- Lib/asyncio/streams.py | 38 +++++++++++++-------------- Lib/test/test_asyncio/test_streams.py | 6 +++-- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index b48ccbab409bb3..09cdad8958b407 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -96,11 +96,8 @@ async def start_server(client_connected_cb, host=None, port=None, *, loop = events.get_event_loop() def factory(): - stream = Stream(mode=StreamMode.READWRITE, - limit=limit, - loop=loop, - _asyncio_internal=True) - protocol = _ServerStreamProtocol(stream, client_connected_cb, + protocol = _ServerStreamProtocol(limit, + client_connected_cb, loop=loop, _asyncio_internal=True) return protocol @@ -134,11 +131,7 @@ async def start_unix_server(client_connected_cb, path=None, *, loop = events.get_event_loop() def factory(): - stream = Stream(mode=StreamMode.READWRITE, - limit=limit, - loop=loop, - _asyncio_internal=True) - protocol = _ServerStreamProtocol(stream, + protocol = _ServerStreamProtocol(limit, client_connected_cb, loop=loop, _asyncio_internal=True) @@ -233,7 +226,6 @@ class _BaseStreamProtocol(FlowControlMixin, protocols.Protocol): call inappropriate methods of the protocol.) """ - _source_traceback = None _stream = None # initialized in derived classes def __init__(self, loop=None, @@ -245,11 +237,6 @@ def __init__(self, loop=None, def connection_made(self, transport): self._transport = transport - stream = self._stream - if stream is None: - return - stream.set_transport(transport) - stream._protocol = self self._over_ssl = transport.get_extra_info('sslcontext') is not None def connection_lost(self, exc): @@ -295,6 +282,8 @@ def __del__(self): class _StreamProtocol(_BaseStreamProtocol): + _source_traceback = None + def __init__(self, stream, loop=None, *, _asyncio_internal=False): super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) @@ -337,6 +326,11 @@ def connection_made(self, transport): transport.abort() return super().connection_made(transport) + stream = self._stream + if stream is None: + return + stream.set_transport(transport) + stream._protocol = self def connection_lost(self, exc): super().connection_lost(exc) @@ -344,15 +338,21 @@ def connection_lost(self, exc): class _ServerStreamProtocol(_BaseStreamProtocol): - def __init__(self, stream, client_connected_cb, loop=None, + def __init__(self, limit, client_connected_cb, loop=None, *, _asyncio_internal=False): super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) - self._source_traceback = stream._source_traceback - self._stream = stream self._client_connected_cb = client_connected_cb + self._limit = limit def connection_made(self, transport): super().connection_made(transport) + stream = Stream(mode=StreamMode.READWRITE, + transport=transport, + protocol=self, + limit=self._limit, + loop=self._loop, + _asyncio_internal=True) + self._stream = stream res = self._client_connected_cb(self._stream, self._stream) if coroutines.iscoroutine(res): diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index cb84e9f01fff3f..3d46f5ab131544 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1169,7 +1169,8 @@ def test_stream_writer_create_warning(self): def test_stream_reader_forbidden_ops(self): async def inner(): - stream = asyncio.Stream(mode=asyncio.StreamMode.READ) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + _asyncio_internal=True) with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): await stream.write(b'data') with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): @@ -1183,7 +1184,8 @@ async def inner(): def test_stream_writer_forbidden_ops(self): async def inner(): - stream = asyncio.Stream(mode=asyncio.StreamMode.WRITE) + stream = asyncio.Stream(mode=asyncio.StreamMode.WRITE, + _asyncio_internal=True) with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): stream.feed_eof() with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): From 2cc72b8d477b16794d0195b9d0f7d7fb6fd1ef82 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 15 May 2019 12:30:56 +0300 Subject: [PATCH 27/82] Add conntect method --- Lib/asyncio/streams.py | 40 ++++++++++++++++++++++++--- Lib/test/test_asyncio/test_streams.py | 30 ++++++++++++++++++++ 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 1a7510d3fcbd8b..a24e47912775bc 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,6 +1,7 @@ __all__ = ( 'Stream', 'StreamMode', - 'open_connection', 'start_server') + 'open_connection', 'start_server', + 'connect') import enum import socket @@ -9,7 +10,8 @@ import weakref if hasattr(socket, 'AF_UNIX'): - __all__ += ('open_unix_connection', 'start_unix_server') + __all__ += ('open_unix_connection', 'start_unix_server', + 'connect_unix') from . import coroutines from . import events @@ -37,6 +39,20 @@ def _check_write(self): raise RuntimeError("The stream is read-only") +async def connect(host=None, port=None, *, + limit=_DEFAULT_LIMIT, **kwds): + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.READWRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.create_connection( + lambda: _StreamProtocol(stream, loop=loop, + _asyncio_internal=True), + host, port, **kwds) + return stream + + async def open_connection(host=None, port=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. @@ -62,7 +78,7 @@ async def open_connection(host=None, port=None, *, limit=limit, loop=loop, _asyncio_internal=True) - transport, _ = await loop.create_connection( + await loop.create_connection( lambda: _StreamProtocol(stream, loop=loop, _asyncio_internal=True), host, port, **kwds) @@ -117,13 +133,29 @@ async def open_unix_connection(path=None, *, limit=limit, loop=loop, _asyncio_internal=True) - transport, _ = await loop.create_unix_connection( + await loop.create_unix_connection( lambda: _StreamProtocol(stream, loop=loop, _asyncio_internal=True), path, **kwds) return stream, stream + async def connect_unix(path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `connect()` but works with UNIX Domain Sockets.""" + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.READWRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.create_unix_connection( + lambda: _StreamProtocol(stream, + loop=loop, + _asyncio_internal=True), + path, **kwds) + return stream + + async def start_unix_server(client_connected_cb, path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `start_server` but works with UNIX Domain Sockets.""" diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 9f918534060535..818a415109ee71 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1186,5 +1186,35 @@ async def inner(): self.loop.run_until_complete(inner()) + def _basetest_connect(self, stream): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + stream.write(b'GET / HTTP/1.0\r\n\r\n') + f = stream.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = stream.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + stream.close() + self.loop.run_until_complete(stream.wait_closed()) + + self.assertEqual([], messages) + + def test_connect(self): + with test_utils.run_test_server() as httpd: + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + self._basetest_connect(stream) + + @support.skip_unless_bind_unix_socket + def test_connect_unix(self): + with test_utils.run_test_unix_server() as httpd: + stream = self.loop.run_until_complete( + asyncio.connect_unix(httpd.address)) + self._basetest_connect(stream) + + if __name__ == '__main__': unittest.main() From e7fcb06add5474380e2cada756042c33fc5f3e65 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 17 May 2019 14:55:01 +0300 Subject: [PATCH 28/82] Work on --- Lib/asyncio/streams.py | 100 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 94 insertions(+), 6 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index a24e47912775bc..ea982181ca559c 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -40,7 +40,12 @@ def _check_write(self): async def connect(host=None, port=None, *, - limit=_DEFAULT_LIMIT, **kwds): + limit=_DEFAULT_LIMIT, + ssl=None, family=0, proto=0, + flags=0, sock=None, local_addr=None, + server_hostname=None, + ssl_handshake_timeout=None, + happy_eyeballs_delay=None, interleave=None): loop = events.get_running_loop() stream = Stream(mode=StreamMode.READWRITE, limit=limit, @@ -49,7 +54,12 @@ async def connect(host=None, port=None, *, await loop.create_connection( lambda: _StreamProtocol(stream, loop=loop, _asyncio_internal=True), - host, port, **kwds) + host, port, + ssl=ssl, family=family, proto=proto, + flags=flags, sock=sock, local_addr=local_addr, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + happy_eyeballs_delay=happy_eyeballs_delay, interleave=interleave) return stream @@ -121,6 +131,75 @@ def factory(): return await loop.create_server(factory, host, port, **kwds) +class ServerStream: + def __init__(self, client_connected_cb, host=None, port=None, *, + limit=_DEFAULT_LIMIT, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, sock=None, backlog=100, + ssl=None, reuse_address=None, reuse_port=None, + ssl_handshake_timeout=None): + self._client_connected_cb = client_connected_cb + self._host = host + self._port = port + self._limit = limit + self._family = family + self._flags = flags + self._sock = sock + self._backlog = backlog + self._ssl = ssl + self._reuse_address = reuse_address + self._reuse_port = reuse_port + self._ssl_handshake_timeout = ssl_handshake_timeout + self._loop = asyncio.get_running_loop() + self._low_server = None + + async def __aenter__(self): + def factory(): + protocol = _ServerStreamProtocol(self._limit, + self._client_connected_cb, + loop=self._loop, + _asyncio_internal=True) + return protocol + self._low_server = await self._loop.create_server( + factory, + self._host, + self._port, + start_serving=False, + family=self._family, + flags=self._flags, + sock=self._sock, + backlog=self._backlog, + ssl=self._ssl, + reuse_address=self._reuse_address, + reuse_port=self._reuse_port, + ssl_handshake_timeout=self._ssl_handshake_timeout) + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + await self.close() + + def is_serving(self): + if self._low_server is None: + return False + return self._low_server.is_serving() + + async def start_serving(self): + assert self._low_server is not None + await self._low_server.start_serving() + + async def serve_forever(self): + assert self._low_server is not None + await self._low_server.serve_forever() + + async def close(self): + assert self._low_server is not None + self._low_server.close() + await self._low_server.wait_closed() + + async def abort(self): + pass + + if hasattr(socket, 'AF_UNIX'): # UNIX Domain Sockets are supported on this platform @@ -141,7 +220,10 @@ async def open_unix_connection(path=None, *, return stream, stream async def connect_unix(path=None, *, - loop=None, limit=_DEFAULT_LIMIT, **kwds): + limit=_DEFAULT_LIMIT, + ssl=None, sock=None, + server_hostname=None, + ssl_handshake_timeout=None): """Similar to `connect()` but works with UNIX Domain Sockets.""" loop = events.get_running_loop() stream = Stream(mode=StreamMode.READWRITE, @@ -152,7 +234,11 @@ async def connect_unix(path=None, *, lambda: _StreamProtocol(stream, loop=loop, _asyncio_internal=True), - path, **kwds) + path, + ssl=ssl, + sock=sock, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) return stream @@ -385,9 +471,9 @@ def connection_made(self, transport): loop=self._loop, _asyncio_internal=True) self._stream = stream - res = self._client_connected_cb(self._stream, - self._stream) + res = self._client_connected_cb(self._stream, self._stream) if coroutines.iscoroutine(res): + # TODO: wait for task finish in new API self._loop.create_task(res) def connection_lost(self, exc): @@ -399,6 +485,8 @@ def _swallow_unhandled_exception(task): # Do a trick to suppress unhandled exception # if stream.write() was used without await and # stream.drain() was paused and resumed with an exception + + # TODO: add if not task.cancelled() check!!!! task.exception() From 775860f55711f5638518bcee9a0f06bfa50df120 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 17 May 2019 17:45:18 +0300 Subject: [PATCH 29/82] Sketch StreamServer class --- Lib/asyncio/streams.py | 108 +++++++++++++++++++++++++++++++++-------- 1 file changed, 88 insertions(+), 20 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index ea982181ca559c..f3b1e0ca6a497e 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -122,22 +122,25 @@ async def start_server(client_connected_cb, host=None, port=None, *, loop = events.get_event_loop() def factory(): - protocol = _ServerStreamProtocol(limit, - client_connected_cb, - loop=loop, - _asyncio_internal=True) + protocol = _LegacyServerStreamProtocol(limit, + client_connected_cb, + loop=loop, + _asyncio_internal=True) return protocol return await loop.create_server(factory, host, port, **kwds) class ServerStream: + # TODO: API for enumerating open server streams + def __init__(self, client_connected_cb, host=None, port=None, *, limit=_DEFAULT_LIMIT, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + shutdown_timeout=60): self._client_connected_cb = client_connected_cb self._host = host self._port = port @@ -152,8 +155,12 @@ def __init__(self, client_connected_cb, host=None, port=None, *, self._ssl_handshake_timeout = ssl_handshake_timeout self._loop = asyncio.get_running_loop() self._low_server = None + self._streams = {} + self._shutdown_timeout = shutdown_timeout - async def __aenter__(self): + async def bind(self): + if self._low_server is not None: + return def factory(): protocol = _ServerStreamProtocol(self._limit, self._client_connected_cb, @@ -173,10 +180,6 @@ def factory(): reuse_address=self._reuse_address, reuse_port=self._reuse_port, ssl_handshake_timeout=self._ssl_handshake_timeout) - return self - - async def __aexit__(self, exc_type, exc_value, exc_tb): - await self.close() def is_serving(self): if self._low_server is None: @@ -184,20 +187,58 @@ def is_serving(self): return self._low_server.is_serving() async def start_serving(self): - assert self._low_server is not None + await self.bind() await self._low_server.start_serving() async def serve_forever(self): - assert self._low_server is not None + await self.start_serving() await self._low_server.serve_forever() async def close(self): - assert self._low_server is not None + if self._low_server is None: + return self._low_server.close() + tasks = list(self._streams.values()) + await asyncio.gather(*[stream.close() for stream in self._streams]) await self._low_server.wait_closed() + await self._warn_active_tasks(tasks) async def abort(self): - pass + if self._low_server is None: + return + self._low_server.close() + tasks = list(self._streams.values()) + await asyncio.gather(*[stream.abort() for stream in self._streams]) + await self._low_server.wait_closed() + await self._warn_active_tasks(tasks) + + async def __aenter__(self): + await self.bind() + return self + + async def __aexit__(self, exc_type, exc_value, exc_tb): + await self.close() + + def _attach(self, stream, task): + self._streams[stream] = task + + def _detach(self, stream, task): + del self._streams[stream] + + async def _warn_active_tasks(tasks): + if not tasks: + return + + done, pending = await asyncio.wait(tasks, timeout=self._shutdown_timeout) + if not pending: + return + for task in pending: + task.cancel() + done, pending = await asyncio.wait(pending, timeout=self._shutdown_timeout) + for task in pending: + self._loop.call_exception_handler({ + "message": f'{task} has not finished on stream server closing' + }) if hasattr(socket, 'AF_UNIX'): @@ -249,10 +290,10 @@ async def start_unix_server(client_connected_cb, path=None, *, loop = events.get_event_loop() def factory(): - protocol = _ServerStreamProtocol(limit, - client_connected_cb, - loop=loop, - _asyncio_internal=True) + protocol = _LegacyServerStreamProtocol(limit, + client_connected_cb, + loop=loop, + _asyncio_internal=True) return protocol return await loop.create_unix_server(factory, path, **kwds) @@ -455,7 +496,7 @@ def connection_lost(self, exc): self._stream_wr = None -class _ServerStreamProtocol(_BaseStreamProtocol): +class _LegacyServerStreamProtocol(_BaseStreamProtocol): def __init__(self, limit, client_connected_cb, loop=None, *, _asyncio_internal=False): super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) @@ -473,7 +514,6 @@ def connection_made(self, transport): self._stream = stream res = self._client_connected_cb(self._stream, self._stream) if coroutines.iscoroutine(res): - # TODO: wait for task finish in new API self._loop.create_task(res) def connection_lost(self, exc): @@ -481,6 +521,34 @@ def connection_lost(self, exc): self._stream = None +class _ServerStreamProtocol(_BaseStreamProtocol): + def __init__(self, server, limit, client_connected_cb, loop=None, + *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + self._client_connected_cb = client_connected_cb + self._limit = limit + self._server = server + self._task = None + + def connection_made(self, transport): + super().connection_made(transport) + stream = Stream(mode=StreamMode.READWRITE, + transport=transport, + protocol=self, + limit=self._limit, + loop=self._loop, + _asyncio_internal=True) + self._stream = stream + self._task = self._loop.create_task( + self._client_connected_cb(self._stream)) + self._server._attach(stream, self._task) + + def connection_lost(self, exc): + super().connection_lost(exc) + self._server._detach(self._stream, self._task) + self._stream = None + + def _swallow_unhandled_exception(task): # Do a trick to suppress unhandled exception # if stream.write() was used without await and From 75421d751ee59982ce5fe0a1bc3238f3d3fbc613 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 20 May 2019 14:37:09 +0300 Subject: [PATCH 30/82] Work on --- Lib/asyncio/streams.py | 153 ++++++++++++++++++-------- Lib/test/test_asyncio/test_streams.py | 39 +++++++ 2 files changed, 146 insertions(+), 46 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index f3b1e0ca6a497e..8f00ab65562097 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,7 +1,8 @@ __all__ = ( 'Stream', 'StreamMode', 'open_connection', 'start_server', - 'connect') + 'connect', + 'StreamServer') import enum import socket @@ -11,7 +12,8 @@ if hasattr(socket, 'AF_UNIX'): __all__ += ('open_unix_connection', 'start_unix_server', - 'connect_unix') + 'connect_unix', + 'UnixStreamServer') from . import coroutines from . import events @@ -19,7 +21,7 @@ from . import format_helpers from . import protocols from .log import logger -from .tasks import sleep +from .tasks import gather, sleep, wait _DEFAULT_LIMIT = 2 ** 16 # 64 KiB @@ -131,29 +133,18 @@ def factory(): return await loop.create_server(factory, host, port, **kwds) -class ServerStream: +class _BaseStreamServer: # TODO: API for enumerating open server streams - def __init__(self, client_connected_cb, host=None, port=None, *, - limit=_DEFAULT_LIMIT, - family=socket.AF_UNSPEC, - flags=socket.AI_PASSIVE, sock=None, backlog=100, - ssl=None, reuse_address=None, reuse_port=None, - ssl_handshake_timeout=None, - shutdown_timeout=60): + def __init__(self, client_connected_cb, + limit=_DEFAULT_LIMIT, + shutdown_timeout=60, + _asyncio_internal=False): + if not _asyncio_internal: + raise RuntimeError("_ServerStream is a private asyncio class") self._client_connected_cb = client_connected_cb - self._host = host - self._port = port self._limit = limit - self._family = family - self._flags = flags - self._sock = sock - self._backlog = backlog - self._ssl = ssl - self._reuse_address = reuse_address - self._reuse_port = reuse_port - self._ssl_handshake_timeout = ssl_handshake_timeout - self._loop = asyncio.get_running_loop() + self._loop = events.get_running_loop() self._low_server = None self._streams = {} self._shutdown_timeout = shutdown_timeout @@ -161,25 +152,10 @@ def __init__(self, client_connected_cb, host=None, port=None, *, async def bind(self): if self._low_server is not None: return - def factory(): - protocol = _ServerStreamProtocol(self._limit, - self._client_connected_cb, - loop=self._loop, - _asyncio_internal=True) - return protocol - self._low_server = await self._loop.create_server( - factory, - self._host, - self._port, - start_serving=False, - family=self._family, - flags=self._flags, - sock=self._sock, - backlog=self._backlog, - ssl=self._ssl, - reuse_address=self._reuse_address, - reuse_port=self._reuse_port, - ssl_handshake_timeout=self._ssl_handshake_timeout) + self._low_server = await self._bind() + + def is_bound(self): + return self._low_server is not None def is_serving(self): if self._low_server is None: @@ -199,7 +175,7 @@ async def close(self): return self._low_server.close() tasks = list(self._streams.values()) - await asyncio.gather(*[stream.close() for stream in self._streams]) + await gather(*[stream.close() for stream in self._streams]) await self._low_server.wait_closed() await self._warn_active_tasks(tasks) @@ -208,7 +184,7 @@ async def abort(self): return self._low_server.close() tasks = list(self._streams.values()) - await asyncio.gather(*[stream.abort() for stream in self._streams]) + await gather(*[stream.abort() for stream in self._streams]) await self._low_server.wait_closed() await self._warn_active_tasks(tasks) @@ -225,22 +201,68 @@ def _attach(self, stream, task): def _detach(self, stream, task): del self._streams[stream] - async def _warn_active_tasks(tasks): + async def _warn_active_tasks(self, tasks): if not tasks: return - done, pending = await asyncio.wait(tasks, timeout=self._shutdown_timeout) + done, pending = await wait(tasks, timeout=self._shutdown_timeout) if not pending: return for task in pending: task.cancel() - done, pending = await asyncio.wait(pending, timeout=self._shutdown_timeout) + done, pending = await wait(pending, timeout=self._shutdown_timeout) for task in pending: self._loop.call_exception_handler({ "message": f'{task} has not finished on stream server closing' }) +class StreamServer(_BaseStreamServer): + + def __init__(self, client_connected_cb, host=None, port=None, *, + limit=_DEFAULT_LIMIT, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, sock=None, backlog=100, + ssl=None, reuse_address=None, reuse_port=None, + ssl_handshake_timeout=None, + shutdown_timeout=60): + super().__init__(client_connected_cb, + limit=limit, + shutdown_timeout=shutdown_timeout, + _asyncio_internal=True) + self._host = host + self._port = port + self._family = family + self._flags = flags + self._sock = sock + self._backlog = backlog + self._ssl = ssl + self._reuse_address = reuse_address + self._reuse_port = reuse_port + self._ssl_handshake_timeout = ssl_handshake_timeout + + async def _bind(self): + def factory(): + protocol = _ServerStreamProtocol(self._limit, + self._client_connected_cb, + loop=self._loop, + _asyncio_internal=True) + return protocol + return await self._loop.create_server( + factory, + self._host, + self._port, + start_serving=False, + family=self._family, + flags=self._flags, + sock=self._sock, + backlog=self._backlog, + ssl=self._ssl, + reuse_address=self._reuse_address, + reuse_port=self._reuse_port, + ssl_handshake_timeout=self._ssl_handshake_timeout) + + if hasattr(socket, 'AF_UNIX'): # UNIX Domain Sockets are supported on this platform @@ -298,6 +320,41 @@ def factory(): return await loop.create_unix_server(factory, path, **kwds) + class UnixStreamServer(_BaseStreamServer): + + def __init__(self, client_connected_cb, path=None, *, + limit=_DEFAULT_LIMIT, + sock=None, + backlog=100, + ssl=None, + ssl_handshake_timeout=None, + shutdown_timeout=60): + super().__init__(client_connected_cb, + limit=limit, + shutdown_timeout=shutdown_timeout, + _asyncio_internal=True) + self._path = path + self._sock = sock + self._backlog = backlog + self._ssl = ssl + self._ssl_handshake_timeout = ssl_handshake_timeout + + async def _bind(self): + def factory(): + protocol = _ServerStreamProtocol(self._limit, + self._client_connected_cb, + loop=self._loop, + _asyncio_internal=True) + return protocol + return await self._loop.create_unix_server( + factory, + self._path, + start_serving=False, + sock=self._sock, + backlog=self._backlog, + ssl=self._ssl, + ssl_handshake_timeout=self._ssl_handshake_timeout) + class FlowControlMixin(protocols.Protocol): """Reusable flow control logic for StreamWriter.drain(). @@ -683,6 +740,10 @@ def close(self): def is_closing(self): return self._transport.is_closing() + async def abort(self): + self._transport.abort() + await self.wait_closed() + async def wait_closed(self): await self._protocol._get_close_waiter(self) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 818a415109ee71..fcd2ecde0908da 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1,5 +1,6 @@ """Tests for streams.py.""" +import contextlib import gc import os import queue @@ -1215,6 +1216,44 @@ def test_connect_unix(self): asyncio.connect_unix(httpd.address)) self._basetest_connect(stream) + def test_stream_server(self): + + @contextlib.asynccontextmanager + async def server(): + + async def handle_client(self, stream): + data = await stream.readline() + await stream.write(data) + await stream.close() + + sock = socket.create_server(('127.0.0.1', 0)) + async with asyncio.StreamServer(handle_client, sock=sock) as server: + yield server, sock.getsockname() + await server.serve_forever() + + async def client(srv, addr): + stream = await asyncio.connect(*addr) + # send a line + await stream.write(b"hello world!\n") + # read it back + msgback = await stream.readline() + await stream.close() + self.assertEqual(msgback, b"hello world!\n") + await srv.close() + + async def test(): + async with server() as (srv, addr): + task = asyncio.create_task(client(srv, addr)) + await srv.serve_forever() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + if __name__ == '__main__': unittest.main() From 7d11e4fad54dd75342a6cc7b7e79a51e6e72a9c6 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 14:49:28 +0300 Subject: [PATCH 31/82] Fix stream test --- Lib/asyncio/streams.py | 39 ++++++++++++++++++++++++--- Lib/test/test_asyncio/test_streams.py | 29 +++++++++----------- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 8f00ab65562097..c98bd419fde9b1 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -135,6 +135,14 @@ def factory(): class _BaseStreamServer: # TODO: API for enumerating open server streams + # TODO: add __repr__ + + # Design note. + # StreamServer and UnixStreamServer are exposed as FINAL classes, not function + # factories. + # async with serve(host, port) as server: + # server.start_serving() + # looks ugly. def __init__(self, client_connected_cb, limit=_DEFAULT_LIMIT, @@ -155,9 +163,24 @@ async def bind(self): self._low_server = await self._bind() def is_bound(self): + # TODO: make is_bound and is_serving properties? return self._low_server is not None + def served_names(self): + # The property name is questionable + # Also, should it be a property? + # API consistency does matter + # I don't want to expose plain socket.socket objects as low-level + # asyncio.Server does but exposing served IP addresses or unix paths + # is useful + # + # multiple value for socket bound to both IPv4 and IPv6 families + if self._low_server is None: + return [] + return [sock.getsockname() for sock in self._low_server.sockets] + def is_serving(self): + # TODO: make is_bound and is_serving properties? if self._low_server is None: return False return self._low_server.is_serving() @@ -226,6 +249,10 @@ def __init__(self, client_connected_cb, host=None, port=None, *, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=None, shutdown_timeout=60): + # client_connected_cb name is consistent with legacy API + # but it is long and ugly + # any suggestion? + super().__init__(client_connected_cb, limit=limit, shutdown_timeout=shutdown_timeout, @@ -243,7 +270,8 @@ def __init__(self, client_connected_cb, host=None, port=None, *, async def _bind(self): def factory(): - protocol = _ServerStreamProtocol(self._limit, + protocol = _ServerStreamProtocol(self, + self._limit, self._client_connected_cb, loop=self._loop, _asyncio_internal=True) @@ -341,7 +369,8 @@ def __init__(self, client_connected_cb, path=None, *, async def _bind(self): def factory(): - protocol = _ServerStreamProtocol(self._limit, + protocol = _ServerStreamProtocol(self, + self._limit, self._client_connected_cb, loop=self._loop, _asyncio_internal=True) @@ -492,7 +521,7 @@ def _get_close_waiter(self, stream): def __del__(self): # Prevent reports about unhandled exceptions. # Better than self._closed._log_traceback = False hack - closed = self._closed + closed = self._get_close_waiter(self._stream) if closed.done() and not closed.cancelled(): closed.exception() @@ -582,6 +611,7 @@ class _ServerStreamProtocol(_BaseStreamProtocol): def __init__(self, server, limit, client_connected_cb, loop=None, *, _asyncio_internal=False): super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + assert self._closed self._client_connected_cb = client_connected_cb self._limit = limit self._server = server @@ -596,6 +626,9 @@ def connection_made(self, transport): loop=self._loop, _asyncio_internal=True) self._stream = stream + # TODO: log a case when task cannot be created. + # Usualy it means that _client_connected_cb + # has incompatible signature. self._task = self._loop.create_task( self._client_connected_cb(self._stream)) self._server._attach(stream, self._task) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index fcd2ecde0908da..71bbb2996a8537 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1218,20 +1218,14 @@ def test_connect_unix(self): def test_stream_server(self): - @contextlib.asynccontextmanager - async def server(): - - async def handle_client(self, stream): - data = await stream.readline() - await stream.write(data) - await stream.close() + async def handle_client(stream): + data = await stream.readline() + await stream.write(data) + await stream.close() - sock = socket.create_server(('127.0.0.1', 0)) - async with asyncio.StreamServer(handle_client, sock=sock) as server: - yield server, sock.getsockname() - await server.serve_forever() - async def client(srv, addr): + async def client(srv): + addr = srv.served_names()[0] stream = await asyncio.connect(*addr) # send a line await stream.write(b"hello world!\n") @@ -1242,15 +1236,16 @@ async def client(srv, addr): await srv.close() async def test(): - async with server() as (srv, addr): - task = asyncio.create_task(client(srv, addr)) - await srv.serve_forever() - await task + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: + await server.start_serving() + asyncio.create_task(client(server)) + await server.serve_forever() messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - self.loop.run_until_complete(test()) + with contextlib.suppress(asyncio.CancelledError): + self.loop.run_until_complete(test()) self.assertEqual(messages, []) From 03cc501755dea75c8ad0b3137d0cf94d055ec38e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 15:03:11 +0300 Subject: [PATCH 32/82] Test cleanup --- Lib/test/test_asyncio/test_streams.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 71bbb2996a8537..e08ad55a711f80 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1238,14 +1238,14 @@ async def client(srv): async def test(): async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: await server.start_serving() - asyncio.create_task(client(server)) - await server.serve_forever() + task = asyncio.create_task(client(server)) + with contextlib.suppress(asyncio.CancelledError): + await server.serve_forever() + await task messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - - with contextlib.suppress(asyncio.CancelledError): - self.loop.run_until_complete(test()) + self.loop.run_until_complete(test()) self.assertEqual(messages, []) From ca7f4790bf27b2c174d86380b0dbb7cbc858f025 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 15:11:42 +0300 Subject: [PATCH 33/82] Add unix server test --- Lib/test/test_asyncio/test_streams.py | 34 ++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 93c9865024caf7..2accfd913cf0cd 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1222,7 +1222,6 @@ async def handle_client(stream): await stream.write(data) await stream.close() - async def client(srv): addr = srv.served_names()[0] stream = await asyncio.connect(*addr) @@ -1247,6 +1246,39 @@ async def test(): self.loop.run_until_complete(test()) self.assertEqual(messages, []) + @support.skip_unless_bind_unix_socket + def test_unix_stream_server(self): + + async def handle_client(stream): + data = await stream.readline() + await stream.write(data) + await stream.close() + + async def client(srv): + addr = srv.served_names()[0] + stream = await asyncio.connect_unix(addr) + # send a line + await stream.write(b"hello world!\n") + # read it back + msgback = await stream.readline() + await stream.close() + self.assertEqual(msgback, b"hello world!\n") + await srv.close() + + async def test(): + with test_utils.unix_socket_path() as path: + async with asyncio.UnixStreamServer(handle_client, path) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + with contextlib.suppress(asyncio.CancelledError): + await server.serve_forever() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + if __name__ == '__main__': From 7ff3b9fa357623047f40fe4f5df4b270ec58734d Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 15:24:30 +0300 Subject: [PATCH 34/82] Add tests for server.bind() --- Lib/test/test_asyncio/test_streams.py | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 2accfd913cf0cd..2d72bab52d6b1f 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1279,6 +1279,37 @@ async def test(): self.loop.run_until_complete(test()) self.assertEqual(messages, []) + def test_stream_server_bind(self): + async def handle_client(stream): + await stream.close() + + async def test(): + srv = asyncio.StreamServer(handle_client, '127.0.0.1', 0) + self.assertFalse(srv.is_bound()) + self.assertEqual([], srv.served_names()) + await srv.bind() + self.assertTrue(srv.is_bound()) + self.assertEqual(1, len(srv.served_names())) + await srv.close() + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + def test_stream_server_bind_async_with(self): + async def handle_client(stream): + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv: + self.assertTrue(srv.is_bound()) + self.assertEqual(1, len(srv.served_names())) + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) if __name__ == '__main__': From a34bbed175b4a55467acc95343d35e34c7bf9536 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 15:24:55 +0300 Subject: [PATCH 35/82] Rename private method --- Lib/asyncio/streams.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index c98bd419fde9b1..9b3b5f8222a131 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -200,7 +200,7 @@ async def close(self): tasks = list(self._streams.values()) await gather(*[stream.close() for stream in self._streams]) await self._low_server.wait_closed() - await self._warn_active_tasks(tasks) + await self._shutdown_active_tasks(tasks) async def abort(self): if self._low_server is None: @@ -209,7 +209,7 @@ async def abort(self): tasks = list(self._streams.values()) await gather(*[stream.abort() for stream in self._streams]) await self._low_server.wait_closed() - await self._warn_active_tasks(tasks) + await self._shutdown_active_tasks(tasks) async def __aenter__(self): await self.bind() @@ -224,7 +224,7 @@ def _attach(self, stream, task): def _detach(self, stream, task): del self._streams[stream] - async def _warn_active_tasks(self, tasks): + async def _shutdown_active_tasks(self, tasks): if not tasks: return From 94861968590d8df538507e97ac830c3c6388a7d2 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 15:26:35 +0300 Subject: [PATCH 36/82] More tests --- Lib/test/test_asyncio/test_streams.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 2d72bab52d6b1f..0c49d5acbdcb56 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1311,6 +1311,21 @@ async def test(): self.loop.run_until_complete(test()) self.assertEqual(messages, []) + def test_stream_server_start_serving(self): + async def handle_client(stream): + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv: + self.assertFalse(srv.is_serving()) + await srv.start_serving() + self.assertTrue(srv.is_serving()) + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + if __name__ == '__main__': unittest.main() From 4a4a2e085651502526ec6ce4fcc10959b39e2255 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 15:51:37 +0300 Subject: [PATCH 37/82] More tests --- Lib/asyncio/streams.py | 5 ++++- Lib/test/test_asyncio/test_streams.py | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9b3b5f8222a131..50d7a635b98740 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -201,6 +201,7 @@ async def close(self): await gather(*[stream.close() for stream in self._streams]) await self._low_server.wait_closed() await self._shutdown_active_tasks(tasks) + self._low_server = None async def abort(self): if self._low_server is None: @@ -210,6 +211,7 @@ async def abort(self): await gather(*[stream.abort() for stream in self._streams]) await self._low_server.wait_closed() await self._shutdown_active_tasks(tasks) + self._low_server = None async def __aenter__(self): await self.bind() @@ -227,7 +229,8 @@ def _detach(self, stream, task): async def _shutdown_active_tasks(self, tasks): if not tasks: return - + # NOTE: tasks finished with exception are reported + # by Tast/Future __del__ method done, pending = await wait(tasks, timeout=self._shutdown_timeout) if not pending: return diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 0c49d5acbdcb56..7d4a01942bd0fb 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1291,6 +1291,8 @@ async def test(): self.assertTrue(srv.is_bound()) self.assertEqual(1, len(srv.served_names())) await srv.close() + self.assertFalse(srv.is_bound()) + self.assertEqual([], srv.served_names()) messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1320,6 +1322,8 @@ async def test(): self.assertFalse(srv.is_serving()) await srv.start_serving() self.assertTrue(srv.is_serving()) + await srv.close() + self.assertFalse(srv.is_serving()) messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) From 414939bd718e29c004616f2c9cab404ceb73632e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 16:03:24 +0300 Subject: [PATCH 38/82] Drop low-level server early --- Lib/asyncio/streams.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 50d7a635b98740..801d1397912435 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -21,7 +21,7 @@ from . import format_helpers from . import protocols from .log import logger -from .tasks import gather, sleep, wait +from .tasks import sleep, wait _DEFAULT_LIMIT = 2 ** 16 # 64 KiB @@ -198,20 +198,20 @@ async def close(self): return self._low_server.close() tasks = list(self._streams.values()) - await gather(*[stream.close() for stream in self._streams]) + await wait([stream.close() for stream in self._streams]) await self._low_server.wait_closed() - await self._shutdown_active_tasks(tasks) self._low_server = None + await self._shutdown_active_tasks(tasks) async def abort(self): if self._low_server is None: return self._low_server.close() tasks = list(self._streams.values()) - await gather(*[stream.abort() for stream in self._streams]) + await wait([stream.abort() for stream in self._streams]) await self._low_server.wait_closed() - await self._shutdown_active_tasks(tasks) self._low_server = None + await self._shutdown_active_tasks(tasks) async def __aenter__(self): await self.bind() From c62b8b40f25eb33fc1a6a7fae6eab49e3b9139bf Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 17:44:02 +0300 Subject: [PATCH 39/82] More tests --- Lib/asyncio/streams.py | 10 ++++- Lib/test/test_asyncio/test_streams.py | 65 +++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 801d1397912435..18896e686c8889 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -197,8 +197,10 @@ async def close(self): if self._low_server is None: return self._low_server.close() + streams = list(self._streams.keys()) tasks = list(self._streams.values()) - await wait([stream.close() for stream in self._streams]) + if tasks: + await wait([stream.close() for stream in streams]) await self._low_server.wait_closed() self._low_server = None await self._shutdown_active_tasks(tasks) @@ -207,8 +209,10 @@ async def abort(self): if self._low_server is None: return self._low_server.close() + streams = list(self._streams.keys()) tasks = list(self._streams.values()) - await wait([stream.abort() for stream in self._streams]) + if streams: + await wait([stream.abort() for stream in streams]) await self._low_server.wait_closed() self._low_server = None await self._shutdown_active_tasks(tasks) @@ -661,6 +665,8 @@ class Stream: directly. """ + # TODO: add __aenter__ / __aexit__ to close stream + _source_traceback = None def __init__(self, mode, *, diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 7d4a01942bd0fb..1b556970e0efc0 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1330,6 +1330,71 @@ async def test(): self.loop.run_until_complete(test()) self.assertEqual(messages, []) + def test_stream_server_close(self): + server_stream_aborted = False + fut = self.loop.create_future() + + async def handle_client(stream): + await fut + self.assertEqual(b'', await stream.readline()) + nonlocal server_stream_aborted + server_stream_aborted = True + + async def client(srv): + addr = srv.served_names()[0] + stream = await asyncio.connect(*addr) + fut.set_result(None) + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut + await server.close() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut.done()) + self.assertTrue(server_stream_aborted) + + + def test_stream_server_abort(self): + server_stream_aborted = False + fut = self.loop.create_future() + + async def handle_client(stream): + await fut + self.assertEqual(b'', await stream.readline()) + nonlocal server_stream_aborted + server_stream_aborted = True + + async def client(srv): + addr = srv.served_names()[0] + stream = await asyncio.connect(*addr) + fut.set_result(None) + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut + await server.abort() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut.done()) + self.assertTrue(server_stream_aborted) + if __name__ == '__main__': unittest.main() From b84cf84dab51fda8fd2e768989f9994207c3ce8e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 21 May 2019 18:37:17 +0300 Subject: [PATCH 40/82] Add more tests --- Lib/asyncio/streams.py | 8 ++- Lib/test/test_asyncio/test_streams.py | 77 ++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 18896e686c8889..2ef6984843db7d 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -243,7 +243,7 @@ async def _shutdown_active_tasks(self, tasks): done, pending = await wait(pending, timeout=self._shutdown_timeout) for task in pending: self._loop.call_exception_handler({ - "message": f'{task} has not finished on stream server closing' + "message": f'{task} was not finished on stream server closing' }) @@ -1129,3 +1129,9 @@ async def __anext__(self): if val == b'': raise StopAsyncIteration return val + + # async def __aenter__(self): + # return self + + # async def __aexit__(self, exc_type, exc_val, exc_tb): + # await self.close() diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 1b556970e0efc0..0023ea71260d83 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1362,7 +1362,6 @@ async def test(): self.assertTrue(fut.done()) self.assertTrue(server_stream_aborted) - def test_stream_server_abort(self): server_stream_aborted = False fut = self.loop.create_future() @@ -1395,6 +1394,82 @@ async def test(): self.assertTrue(fut.done()) self.assertTrue(server_stream_aborted) + def test_stream_shutdown_hung_task(self): + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() + + async def handle_client(stream): + while True: + await asyncio.sleep(0.01) + + async def client(srv): + addr = srv.served_names()[0] + stream = await asyncio.connect(*addr) + fut1.set_result(None) + await fut2 + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, + '127.0.0.1', + 0, + shutdown_timeout=0.5) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut1 + await server.close() + fut2.set_result(None) + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut1.done()) + self.assertTrue(fut2.done()) + + def test_stream_shutdown_hung_task_prevents_cancellation(self): + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() + do_handle_client = True + + async def handle_client(stream): + while do_handle_client: + with contextlib.suppress(asyncio.CancelledError): + await asyncio.sleep(0.01) + + async def client(srv): + addr = srv.served_names()[0] + stream = await asyncio.connect(*addr) + fut1.set_result(None) + await fut2 + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, + '127.0.0.1', + 0, + shutdown_timeout=0.5) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut1 + await server.close() + nonlocal do_handle_client + do_handle_client = False + fut2.set_result(None) + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(1, len(messages)) + self.assertRegex(messages[0]['message'], + " Date: Wed, 22 May 2019 11:59:52 +0300 Subject: [PATCH 41/82] Forbid stream servers inheritance --- Lib/asyncio/streams.py | 4 ++++ Lib/test/test_asyncio/test_streams.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 2ef6984843db7d..561fd38307549e 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -224,6 +224,10 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, exc_tb): await self.close() + def __init_subclass__(cls): + if not cls.__module__.startswith('asyncio.'): + raise TypeError("Stream server classes are final, don't inherit from them") + def _attach(self, stream, task): self._streams[stream] = task diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 0023ea71260d83..5a6349a9d28324 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1279,6 +1279,17 @@ async def test(): self.loop.run_until_complete(test()) self.assertEqual(messages, []) + def test_stream_server_inheritance_forbidden(self): + with self.assertRaises(TypeError): + class MyServer(asyncio.StreamServer): + pass + + @support.skip_unless_bind_unix_socket + def test_unix_stream_server_inheritance_forbidden(self): + with self.assertRaises(TypeError): + class MyServer(asyncio.UnixStreamServer): + pass + def test_stream_server_bind(self): async def handle_client(stream): await stream.close() From 37a2949a95d8515cac0135ff7f3ec0987a1c8caa Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 12:57:02 +0300 Subject: [PATCH 42/82] Add tests for stream sendfile --- Lib/asyncio/streams.py | 5 ++++ Lib/test/test_asyncio/test_streams.py | 37 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 561fd38307549e..0344baf93655f2 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -815,6 +815,11 @@ async def drain(self): await sleep(0) await self._protocol._drain_helper() + async def sendfile(self, file, offset=0, count=None, *, fallback=True): + await self.drain() # check for stream mode and exceptions + return await self._loop.sendfile(self._transport, file, + offset, count, fallback=fallback) + def exception(self): return self._exception diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 5a6349a9d28324..d1ab0491341a28 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1481,6 +1481,43 @@ async def test(): self.assertTrue(fut1.done()) self.assertTrue(fut2.done()) + def test_sendfile(self): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with open(support.TESTFN, 'wb') as fp: + fp.write(b'data\n') + self.addCleanup(support.unlink, support.TESTFN) + + async def serve_callback(stream): + data = await stream.readline() + self.assertEqual(data, b'begin\n') + data = await stream.readline() + self.assertEqual(data, b'data\n') + data = await stream.readline() + self.assertEqual(data, b'end\n') + await stream.write(b'done\n') + await stream.close() + + async def do_connect(host, port): + stream = await asyncio.connect(host, port) + await stream.write(b'begin\n') + with open(support.TESTFN, 'rb') as fp: + await stream.sendfile(fp) + await stream.write(b'end\n') + data = await stream.readline() + self.assertEqual(data, b'done\n') + await stream.close() + + async def test(): + async with asyncio.StreamServer(serve_callback, 'localhost', 0) as srv: + await srv.start_serving() + await do_connect(*srv.served_names()[0]) + + self.loop.run_until_complete(test()) + + self.assertEqual([], messages) + if __name__ == '__main__': unittest.main() From c73ddfd9e69029ae2dac478e171a42c5f9bf517d Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 13:11:06 +0300 Subject: [PATCH 43/82] Add start_tls method --- Lib/asyncio/streams.py | 20 ++++++++++++++++++++ Lib/test/test_asyncio/test_streams.py | 14 ++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 0344baf93655f2..de5f1666a1c119 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -607,6 +607,7 @@ def connection_made(self, transport): protocol=self, limit=self._limit, loop=self._loop, + is_server_side=True, _asyncio_internal=True) self._stream = stream res = self._client_connected_cb(self._stream, self._stream) @@ -635,6 +636,7 @@ def connection_made(self, transport): protocol=self, limit=self._limit, loop=self._loop, + is_server_side=True, _asyncio_internal=True) self._stream = stream # TODO: log a case when task cannot be created. @@ -678,6 +680,7 @@ def __init__(self, mode, *, protocol=None, loop=None, limit=_DEFAULT_LIMIT, + is_server_side=False, _asyncio_internal=False): if not _asyncio_internal: warnings.warn(f"{self.__class__} should be instaniated " @@ -687,6 +690,7 @@ def __init__(self, mode, *, self._mode = mode self._transport = transport self._protocol = protocol + self._is_server_side = is_server_side # The line length limit is a security feature; # it also doubles as half the buffer limit. @@ -734,6 +738,9 @@ def __repr__(self): def mode(self): return self._mode + def is_server_side(self): + return self._is_server_side + @property def transport(self): return self._transport @@ -820,6 +827,19 @@ async def sendfile(self, file, offset=0, count=None, *, fallback=True): return await self._loop.sendfile(self._transport, file, offset, count, fallback=fallback) + async def start_tls(self, sslcontext, *, + server_hostname=None, + ssl_handshake_timeout=None): + await self.drain() # check for stream mode and exceptions + transport = await self._loop.start_tls( + self._transport, self._protocol, sslcontext, + server_side=self._is_server_side, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) + self._transport = transport + self._protocol._transport = transport + self._protocol._over_ssl = True + def exception(self): return self._exception diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index d1ab0491341a28..a8ba0a2d2f3809 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1519,5 +1519,19 @@ async def test(): self.assertEqual([], messages) + @unittest.skipIf(ssl is None, 'No ssl module') + def test_connect_start_tls(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + # connect without SSL but upgrade to TLS just after + # connection is established + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + + self.loop.run_until_complete( + stream.start_tls( + sslcontext=test_utils.dummy_ssl_context())) + self._basetest_connect(stream) + + if __name__ == '__main__': unittest.main() From 3f9bc91e26dc37a2b1765d391dc52fc9afa665c1 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 13:35:51 +0300 Subject: [PATCH 44/82] served_names() -> addresses() --- Lib/asyncio/streams.py | 5 +---- Lib/test/test_asyncio/test_streams.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index de5f1666a1c119..e2337704c3c69d 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -166,10 +166,7 @@ def is_bound(self): # TODO: make is_bound and is_serving properties? return self._low_server is not None - def served_names(self): - # The property name is questionable - # Also, should it be a property? - # API consistency does matter + def addresses(self): # I don't want to expose plain socket.socket objects as low-level # asyncio.Server does but exposing served IP addresses or unix paths # is useful diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index a8ba0a2d2f3809..3f6d2fc145902f 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1223,7 +1223,7 @@ async def handle_client(stream): await stream.close() async def client(srv): - addr = srv.served_names()[0] + addr = srv.addresses()[0] stream = await asyncio.connect(*addr) # send a line await stream.write(b"hello world!\n") @@ -1255,7 +1255,7 @@ async def handle_client(stream): await stream.close() async def client(srv): - addr = srv.served_names()[0] + addr = srv.addresses()[0] stream = await asyncio.connect_unix(addr) # send a line await stream.write(b"hello world!\n") @@ -1297,13 +1297,13 @@ async def handle_client(stream): async def test(): srv = asyncio.StreamServer(handle_client, '127.0.0.1', 0) self.assertFalse(srv.is_bound()) - self.assertEqual([], srv.served_names()) + self.assertEqual([], srv.addresses()) await srv.bind() self.assertTrue(srv.is_bound()) - self.assertEqual(1, len(srv.served_names())) + self.assertEqual(1, len(srv.addresses())) await srv.close() self.assertFalse(srv.is_bound()) - self.assertEqual([], srv.served_names()) + self.assertEqual([], srv.addresses()) messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1317,7 +1317,7 @@ async def handle_client(stream): async def test(): async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv: self.assertTrue(srv.is_bound()) - self.assertEqual(1, len(srv.served_names())) + self.assertEqual(1, len(srv.addresses())) messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1352,7 +1352,7 @@ async def handle_client(stream): server_stream_aborted = True async def client(srv): - addr = srv.served_names()[0] + addr = srv.addresses()[0] stream = await asyncio.connect(*addr) fut.set_result(None) self.assertEqual(b'', await stream.readline()) @@ -1384,7 +1384,7 @@ async def handle_client(stream): server_stream_aborted = True async def client(srv): - addr = srv.served_names()[0] + addr = srv.addresses()[0] stream = await asyncio.connect(*addr) fut.set_result(None) self.assertEqual(b'', await stream.readline()) @@ -1414,7 +1414,7 @@ async def handle_client(stream): await asyncio.sleep(0.01) async def client(srv): - addr = srv.served_names()[0] + addr = srv.addresses()[0] stream = await asyncio.connect(*addr) fut1.set_result(None) await fut2 @@ -1425,7 +1425,7 @@ async def test(): async with asyncio.StreamServer(handle_client, '127.0.0.1', 0, - shutdown_timeout=0.5) as server: + shutdown_timeout=0.3) as server: await server.start_serving() task = asyncio.create_task(client(server)) await fut1 @@ -1451,7 +1451,7 @@ async def handle_client(stream): await asyncio.sleep(0.01) async def client(srv): - addr = srv.served_names()[0] + addr = srv.addresses()[0] stream = await asyncio.connect(*addr) fut1.set_result(None) await fut2 @@ -1462,7 +1462,7 @@ async def test(): async with asyncio.StreamServer(handle_client, '127.0.0.1', 0, - shutdown_timeout=0.5) as server: + shutdown_timeout=0.3) as server: await server.start_serving() task = asyncio.create_task(client(server)) await fut1 @@ -1512,7 +1512,7 @@ async def do_connect(host, port): async def test(): async with asyncio.StreamServer(serve_callback, 'localhost', 0) as srv: await srv.start_serving() - await do_connect(*srv.served_names()[0]) + await do_connect(*srv.addresses()[0]) self.loop.run_until_complete(test()) From 7eca1fc76af3ac814fcf2c2dc09460891a73963c Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 13:38:11 +0300 Subject: [PATCH 45/82] Drop obsolete TODOs --- Lib/asyncio/streams.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index e2337704c3c69d..3d0bac73d48a0a 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -163,7 +163,6 @@ async def bind(self): self._low_server = await self._bind() def is_bound(self): - # TODO: make is_bound and is_serving properties? return self._low_server is not None def addresses(self): @@ -177,7 +176,6 @@ def addresses(self): return [sock.getsockname() for sock in self._low_server.sockets] def is_serving(self): - # TODO: make is_bound and is_serving properties? if self._low_server is None: return False return self._low_server.is_serving() From ab41dd6bf52fdc678863120a3b43f407d24c1a31 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 13:42:50 +0300 Subject: [PATCH 46/82] Drop redundant TODO --- Lib/asyncio/streams.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 3d0bac73d48a0a..bc26c9b8b5aa3f 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -634,9 +634,8 @@ def connection_made(self, transport): is_server_side=True, _asyncio_internal=True) self._stream = stream - # TODO: log a case when task cannot be created. - # Usualy it means that _client_connected_cb - # has incompatible signature. + # If self._client_connected_cb(self._stream) fails + # the exception is logged by transport self._task = self._loop.create_task( self._client_connected_cb(self._stream)) self._server._attach(stream, self._task) From 1716e921e82ba7ed4d23e12a834e57a38bbc683a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 13:45:51 +0300 Subject: [PATCH 47/82] Add test for stream.is_server_side() --- Lib/test/test_asyncio/test_streams.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 3f6d2fc145902f..096b5da0a636ba 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1206,6 +1206,7 @@ def test_connect(self): with test_utils.run_test_server() as httpd: stream = self.loop.run_until_complete( asyncio.connect(*httpd.address)) + self.assertFalse(stream.is_server_side()) self._basetest_connect(stream) @support.skip_unless_bind_unix_socket @@ -1218,6 +1219,7 @@ def test_connect_unix(self): def test_stream_server(self): async def handle_client(stream): + self.assertTrue(stream.is_server_side()) data = await stream.readline() await stream.write(data) await stream.close() From 6a41b1fd778479dbe18a868e35500a911b25695d Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 17:54:28 +0300 Subject: [PATCH 48/82] Polish --- Lib/asyncio/streams.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index bc26c9b8b5aa3f..16b509b8a5173f 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -134,15 +134,16 @@ def factory(): class _BaseStreamServer: - # TODO: API for enumerating open server streams # TODO: add __repr__ - # Design note. + # Design notes. # StreamServer and UnixStreamServer are exposed as FINAL classes, not function # factories. # async with serve(host, port) as server: # server.start_serving() # looks ugly. + # The class doesn't provide API for enumerating connected streams + # It can be a subject for improvements in Python 3.9 def __init__(self, client_connected_cb, limit=_DEFAULT_LIMIT, @@ -245,6 +246,11 @@ async def _shutdown_active_tasks(self, tasks): "message": f'{task} was not finished on stream server closing' }) + def __del__(self, _warn=warnings.warn): + if self._low_server is not None: + _warn(f"unclosed stream server {self!r}", ResourceWarning, source=self) + self._low_server.close() + class StreamServer(_BaseStreamServer): From 774e9f15ab4ec966433951664e5aa19f56dfe412 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 23:39:12 +0300 Subject: [PATCH 49/82] Fix imports --- Lib/asyncio/streams.py | 44 +++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 16b509b8a5173f..3e61935e0768f6 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -21,7 +21,7 @@ from . import format_helpers from . import protocols from .log import logger -from .tasks import sleep, wait +from . import tasks _DEFAULT_LIMIT = 2 ** 16 # 64 KiB @@ -137,8 +137,8 @@ class _BaseStreamServer: # TODO: add __repr__ # Design notes. - # StreamServer and UnixStreamServer are exposed as FINAL classes, not function - # factories. + # StreamServer and UnixStreamServer are exposed as FINAL classes, + # not function factories. # async with serve(host, port) as server: # server.start_serving() # looks ugly. @@ -194,24 +194,24 @@ async def close(self): return self._low_server.close() streams = list(self._streams.keys()) - tasks = list(self._streams.values()) - if tasks: - await wait([stream.close() for stream in streams]) + active_tasks = list(self._streams.values()) + if streams: + await tasks.wait([stream.close() for stream in streams]) await self._low_server.wait_closed() self._low_server = None - await self._shutdown_active_tasks(tasks) + await self._shutdown_active_tasks(active_tasks) async def abort(self): if self._low_server is None: return self._low_server.close() streams = list(self._streams.keys()) - tasks = list(self._streams.values()) + active_tasks = list(self._streams.values()) if streams: - await wait([stream.abort() for stream in streams]) + await tasks.wait([stream.abort() for stream in streams]) await self._low_server.wait_closed() self._low_server = None - await self._shutdown_active_tasks(tasks) + await self._shutdown_active_tasks(active_tasks) async def __aenter__(self): await self.bind() @@ -222,7 +222,8 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): def __init_subclass__(cls): if not cls.__module__.startswith('asyncio.'): - raise TypeError("Stream server classes are final, don't inherit from them") + raise TypeError("Stream server classes are final, " + "don't inherit from them") def _attach(self, stream, task): self._streams[stream] = task @@ -230,17 +231,19 @@ def _attach(self, stream, task): def _detach(self, stream, task): del self._streams[stream] - async def _shutdown_active_tasks(self, tasks): - if not tasks: + async def _shutdown_active_tasks(self, active_tasks): + if not active_tasks: return # NOTE: tasks finished with exception are reported # by Tast/Future __del__ method - done, pending = await wait(tasks, timeout=self._shutdown_timeout) + done, pending = await tasks.wait(active_tasks, + timeout=self._shutdown_timeout) if not pending: return for task in pending: task.cancel() - done, pending = await wait(pending, timeout=self._shutdown_timeout) + done, pending = await tasks.wait(pending, + timeout=self._shutdown_timeout) for task in pending: self._loop.call_exception_handler({ "message": f'{task} was not finished on stream server closing' @@ -248,7 +251,8 @@ async def _shutdown_active_tasks(self, tasks): def __del__(self, _warn=warnings.warn): if self._low_server is not None: - _warn(f"unclosed stream server {self!r}", ResourceWarning, source=self) + _warn(f"unclosed stream server {self!r}", + ResourceWarning, source=self) self._low_server.close() @@ -756,9 +760,9 @@ def writelines(self, data): return self._fast_drain() def _fast_drain(self): - # The helper tries to use fast-path to return already existing complete future - # object if underlying transport is not paused and actual waiting for writing - # resume is not needed + # The helper tries to use fast-path to return already existing + # complete future object if underlying transport is not paused + #and actual waiting for writing resume is not needed exc = self.exception() if exc is not None: fut = self._loop.create_future() @@ -819,7 +823,7 @@ async def drain(self): # Wait for protocol.connection_lost() call # Raise connection closing error if any, # ConnectionResetError otherwise - await sleep(0) + await tasks.sleep(0) await self._protocol._drain_helper() async def sendfile(self, file, offset=0, count=None, *, fallback=True): From c54f64b33b380a2afe56c352840433b1dd25243a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 23:40:58 +0300 Subject: [PATCH 50/82] Fix inheritance error --- Lib/asyncio/streams.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 3e61935e0768f6..06a461b1e25e76 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -222,8 +222,7 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): def __init_subclass__(cls): if not cls.__module__.startswith('asyncio.'): - raise TypeError("Stream server classes are final, " - "don't inherit from them") + raise TypeError(f"asyncio.{cls.__name__} class cannot be inherited from") def _attach(self, stream, task): self._streams[stream] = task From efcb32654a30465a0c1d5b9bc9f3a92b18dbfbc2 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 23:41:50 +0300 Subject: [PATCH 51/82] Update Lib/asyncio/streams.py Co-Authored-By: Yury Selivanov --- Lib/asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 06a461b1e25e76..d409a6df9eaeab 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -234,7 +234,7 @@ async def _shutdown_active_tasks(self, active_tasks): if not active_tasks: return # NOTE: tasks finished with exception are reported - # by Tast/Future __del__ method + # by the Task.__del__() method. done, pending = await tasks.wait(active_tasks, timeout=self._shutdown_timeout) if not pending: From 066d31746f40827b2c49ae49e7dd8b391ebf96d6 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 23:42:13 +0300 Subject: [PATCH 52/82] Update Lib/asyncio/streams.py Co-Authored-By: Yury Selivanov --- Lib/asyncio/streams.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index d409a6df9eaeab..612e3b73c9f4c8 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -245,7 +245,8 @@ async def _shutdown_active_tasks(self, active_tasks): timeout=self._shutdown_timeout) for task in pending: self._loop.call_exception_handler({ - "message": f'{task} was not finished on stream server closing' + "message": f'{task!r} ignored cancellation request from a closing StreamServer {self!r}', + "stream_server": self }) def __del__(self, _warn=warnings.warn): From b64c9af79605ed3bb977c884bba09b964728aeb3 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 22 May 2019 23:51:51 +0300 Subject: [PATCH 53/82] Update Lib/asyncio/streams.py Co-Authored-By: Yury Selivanov --- Lib/asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 612e3b73c9f4c8..724bad6aeece0c 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -167,7 +167,7 @@ def is_bound(self): return self._low_server is not None def addresses(self): - # I don't want to expose plain socket.socket objects as low-level + # We don't want to expose plain socket.socket objects as low-level # asyncio.Server does but exposing served IP addresses or unix paths # is useful # From 12e4f07390a046447d128eca31dd945e43f74ef3 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 00:15:01 +0300 Subject: [PATCH 54/82] Make client_connected_cb positional-only --- Lib/asyncio/streams.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 06a461b1e25e76..23e1a0b64c5ea5 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -146,6 +146,7 @@ class _BaseStreamServer: # It can be a subject for improvements in Python 3.9 def __init__(self, client_connected_cb, + /, limit=_DEFAULT_LIMIT, shutdown_timeout=60, _asyncio_internal=False): @@ -257,17 +258,13 @@ def __del__(self, _warn=warnings.warn): class StreamServer(_BaseStreamServer): - def __init__(self, client_connected_cb, host=None, port=None, *, + def __init__(self, client_connected_cb, /, host=None, port=None, *, limit=_DEFAULT_LIMIT, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=None, shutdown_timeout=60): - # client_connected_cb name is consistent with legacy API - # but it is long and ugly - # any suggestion? - super().__init__(client_connected_cb, limit=limit, shutdown_timeout=shutdown_timeout, @@ -365,7 +362,7 @@ def factory(): class UnixStreamServer(_BaseStreamServer): - def __init__(self, client_connected_cb, path=None, *, + def __init__(self, client_connected_cb, /, path=None, *, limit=_DEFAULT_LIMIT, sock=None, backlog=100, From 035cf3f194940e66868900f26f86703d682966d7 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 12:23:38 +0300 Subject: [PATCH 55/82] addresses -> listeners --- Lib/asyncio/streams.py | 70 +++++++++++++++++++++++---- Lib/test/test_asyncio/test_streams.py | 24 ++++----- 2 files changed, 72 insertions(+), 22 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 8a8224b1124fae..e4ff259db9977c 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,5 +1,5 @@ __all__ = ( - 'Stream', 'StreamMode', + 'Stream', 'StreamMode', 'Listener', 'open_connection', 'start_server', 'connect', 'StreamServer') @@ -40,6 +40,57 @@ def _check_write(self): if not self & self.WRITE: raise RuntimeError("The stream is read-only") +class Listener: + # The class represents listening sockets served by + # StreamServer and UnixStreamServer + # The exposed API is IPv4/IPv6/UNIX address itself + # and socket fileno + # The socket object is not exposed intentionally + # to prevent wild access from high-level API + # sock = socket.fromfd(fileno, family, type, proto) + # still can be used if the hacking is really necessary + + __slots__ = ('_family', '_type', '_proto', '_addr', '_fileno') + + def __init__(self, sock, _asyncio_internal=False): + if not _asyncio_internal: + raise TypeError(f"{self.__class__} should be instaniated " + "by asyncio internals only") + self._family = sock.family + self._type = sock.type + self._proto = sock.proto + self._addr = sock.getsockname() + self._fileno = sock.fileno() + + @property + def family(self): + return self._family + + @property + def type(self): + return self._type + + @property + def proto(self): + return self._proto + + @property + def addr(self): + return self._addr + + def fileno(self): + return self._fileno + + def __repr__(self): + ret = [f"<{self.__class__.__name__}"] + ret.append(f"family={self.family}") + ret.append(f"type={self.type}") + ret.append(f"proto={self.proto}") + ret.append(f"addr={self.addr}") + ret.append(f"fileno={self.fileno()}") + ret.append(">") + return " ".join(ret) + async def connect(host=None, port=None, *, limit=_DEFAULT_LIMIT, @@ -167,15 +218,12 @@ async def bind(self): def is_bound(self): return self._low_server is not None - def addresses(self): - # We don't want to expose plain socket.socket objects as low-level - # asyncio.Server does but exposing served IP addresses or unix paths - # is useful - # + def listeners(self): # multiple value for socket bound to both IPv4 and IPv6 families if self._low_server is None: - return [] - return [sock.getsockname() for sock in self._low_server.sockets] + return tuple() + return tuple(Listener(sock, _asyncio_internal=True) + for sock in self._low_server.sockets) def is_serving(self): if self._low_server is None: @@ -223,7 +271,8 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): def __init_subclass__(cls): if not cls.__module__.startswith('asyncio.'): - raise TypeError(f"asyncio.{cls.__name__} class cannot be inherited from") + raise TypeError(f"asyncio.{cls.__name__} " + "class cannot be inherited from") def _attach(self, stream, task): self._streams[stream] = task @@ -246,7 +295,8 @@ async def _shutdown_active_tasks(self, active_tasks): timeout=self._shutdown_timeout) for task in pending: self._loop.call_exception_handler({ - "message": f'{task!r} ignored cancellation request from a closing StreamServer {self!r}', + "message": (f'{task!r} ignored cancellation request ' + f'from a closing {self!r}'), "stream_server": self }) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 096b5da0a636ba..1b91c2d7c155d5 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1225,7 +1225,7 @@ async def handle_client(stream): await stream.close() async def client(srv): - addr = srv.addresses()[0] + addr = srv.listeners()[0].addr stream = await asyncio.connect(*addr) # send a line await stream.write(b"hello world!\n") @@ -1257,7 +1257,7 @@ async def handle_client(stream): await stream.close() async def client(srv): - addr = srv.addresses()[0] + addr = srv.listeners()[0].addr stream = await asyncio.connect_unix(addr) # send a line await stream.write(b"hello world!\n") @@ -1299,13 +1299,13 @@ async def handle_client(stream): async def test(): srv = asyncio.StreamServer(handle_client, '127.0.0.1', 0) self.assertFalse(srv.is_bound()) - self.assertEqual([], srv.addresses()) + self.assertEqual(0, len(srv.listeners())) await srv.bind() self.assertTrue(srv.is_bound()) - self.assertEqual(1, len(srv.addresses())) + self.assertEqual(1, len(srv.listeners())) await srv.close() self.assertFalse(srv.is_bound()) - self.assertEqual([], srv.addresses()) + self.assertEqual(0, len(srv.listeners())) messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1319,7 +1319,7 @@ async def handle_client(stream): async def test(): async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv: self.assertTrue(srv.is_bound()) - self.assertEqual(1, len(srv.addresses())) + self.assertEqual(1, len(srv.listeners())) messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1354,7 +1354,7 @@ async def handle_client(stream): server_stream_aborted = True async def client(srv): - addr = srv.addresses()[0] + addr = srv.listeners()[0].addr stream = await asyncio.connect(*addr) fut.set_result(None) self.assertEqual(b'', await stream.readline()) @@ -1386,7 +1386,7 @@ async def handle_client(stream): server_stream_aborted = True async def client(srv): - addr = srv.addresses()[0] + addr = srv.listeners()[0].addr stream = await asyncio.connect(*addr) fut.set_result(None) self.assertEqual(b'', await stream.readline()) @@ -1416,7 +1416,7 @@ async def handle_client(stream): await asyncio.sleep(0.01) async def client(srv): - addr = srv.addresses()[0] + addr = srv.listeners()[0].addr stream = await asyncio.connect(*addr) fut1.set_result(None) await fut2 @@ -1453,7 +1453,7 @@ async def handle_client(stream): await asyncio.sleep(0.01) async def client(srv): - addr = srv.addresses()[0] + addr = srv.listeners()[0].addr stream = await asyncio.connect(*addr) fut1.set_result(None) await fut2 @@ -1479,7 +1479,7 @@ async def test(): self.loop.run_until_complete(test()) self.assertEqual(1, len(messages)) self.assertRegex(messages[0]['message'], - " Date: Thu, 23 May 2019 12:44:24 +0300 Subject: [PATCH 56/82] Revert back StreamReader and StreamWriter --- Lib/asyncio/__init__.py | 9 +- Lib/asyncio/streams.py | 568 +++++++++++++++++++++++--- Lib/test/test_asyncio/test_streams.py | 61 +-- 3 files changed, 553 insertions(+), 85 deletions(-) diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index 8b986e86d17590..00d8bd1ee3c68b 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -50,15 +50,20 @@ def __getattr__(name): + global StreamReader, StreamWriter if name == 'StreamReader': warnings.warn("StreamReader is deprecated, use asyncio.Stream instead", DeprecationWarning, stacklevel=2) - return Stream + from .streams import StreamReader as sr + StreamReader = sr + return StreamReader if name == 'StreamWriter': warnings.warn("StreamWriter is deprecated, use asyncio.Stream instead", DeprecationWarning, stacklevel=2) - return Stream + from .streams import StreamWriter as sw + StreamWriter = sw + return StreamWriter raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index e4ff259db9977c..6a26c70ef136e7 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -137,15 +137,12 @@ async def open_connection(host=None, port=None, *, """ if loop is None: loop = events.get_event_loop() - stream = Stream(mode=StreamMode.READWRITE, - limit=limit, - loop=loop, - _asyncio_internal=True) - await loop.create_connection( - lambda: _StreamProtocol(stream, loop=loop, - _asyncio_internal=True), - host, port, **kwds) - return stream, stream + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop, _asyncio_internal=True) + transport, _ = await loop.create_connection( + lambda: protocol, host, port, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer async def start_server(client_connected_cb, host=None, port=None, *, @@ -175,10 +172,10 @@ async def start_server(client_connected_cb, host=None, port=None, *, loop = events.get_event_loop() def factory(): - protocol = _LegacyServerStreamProtocol(limit, - client_connected_cb, - loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop, + _asyncio_internal=True) return protocol return await loop.create_server(factory, host, port, **kwds) @@ -362,16 +359,13 @@ async def open_unix_connection(path=None, *, """Similar to `open_connection` but works with UNIX Domain Sockets.""" if loop is None: loop = events.get_event_loop() - stream = Stream(mode=StreamMode.READWRITE, - limit=limit, - loop=loop, - _asyncio_internal=True) - await loop.create_unix_connection( - lambda: _StreamProtocol(stream, - loop=loop, - _asyncio_internal=True), - path, **kwds) - return stream, stream + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop, + _asyncio_internal=True) + transport, _ = await loop.create_unix_connection( + lambda: protocol, path, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer async def connect_unix(path=None, *, limit=_DEFAULT_LIMIT, @@ -403,10 +397,10 @@ async def start_unix_server(client_connected_cb, path=None, *, loop = events.get_event_loop() def factory(): - protocol = _LegacyServerStreamProtocol(limit, - client_connected_cb, - loop=loop, - _asyncio_internal=True) + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop, + _asyncio_internal=True) return protocol return await loop.create_unix_server(factory, path, **kwds) @@ -525,6 +519,500 @@ def _get_close_waiter(self, stream): raise NotImplementedError +# begin legacy stream APIs + +class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): + """Helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader, client_connected_cb=None, loop=None, + *, _asyncio_internal=False): + super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) + self._stream_reader = stream_reader + self._stream_writer = None + self._client_connected_cb = client_connected_cb + self._over_ssl = False + self._closed = self._loop.create_future() + + def connection_made(self, transport): + self._stream_reader.set_transport(transport) + self._over_ssl = transport.get_extra_info('sslcontext') is not None + if self._client_connected_cb is not None: + self._stream_writer = StreamWriter(transport, self, + self._stream_reader, + self._loop) + res = self._client_connected_cb(self._stream_reader, + self._stream_writer) + if coroutines.iscoroutine(res): + self._loop.create_task(res) + + def connection_lost(self, exc): + if self._stream_reader is not None: + if exc is None: + self._stream_reader.feed_eof() + else: + self._stream_reader.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) + else: + self._closed.set_exception(exc) + super().connection_lost(exc) + self._stream_reader = None + self._stream_writer = None + + def data_received(self, data): + self._stream_reader.feed_data(data) + + def eof_received(self): + self._stream_reader.feed_eof() + if self._over_ssl: + # Prevent a warning in SSLProtocol.eof_received: + # "returning true from eof_received() + # has no effect when using ssl" + return False + return True + + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + closed = self._closed + if closed.done() and not closed.cancelled(): + closed.exception() + + +class StreamWriter: + """Wraps a Transport. + + This exposes write(), writelines(), [can_]write_eof(), + get_extra_info() and close(). It adds drain() which returns an + optional Future on which you can wait for flow control. It also + adds a transport property which references the Transport + directly. + """ + + def __init__(self, transport, protocol, reader, loop): + self._transport = transport + self._protocol = protocol + # drain() expects that the reader has an exception() method + assert reader is None or isinstance(reader, StreamReader) + self._reader = reader + self._loop = loop + + def __repr__(self): + info = [self.__class__.__name__, f'transport={self._transport!r}'] + if self._reader is not None: + info.append(f'reader={self._reader!r}') + return '<{}>'.format(' '.join(info)) + + @property + def transport(self): + return self._transport + + def write(self, data): + self._transport.write(data) + + def writelines(self, data): + self._transport.writelines(data) + + def write_eof(self): + return self._transport.write_eof() + + def can_write_eof(self): + return self._transport.can_write_eof() + + def close(self): + return self._transport.close() + + def is_closing(self): + return self._transport.is_closing() + + async def wait_closed(self): + await self._protocol._closed + + def get_extra_info(self, name, default=None): + return self._transport.get_extra_info(name, default) + + async def drain(self): + """Flush the write buffer. + + The intended use is to write + + w.write(data) + await w.drain() + """ + if self._reader is not None: + exc = self._reader.exception() + if exc is not None: + raise exc + if self._transport.is_closing(): + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); await drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await tasks.sleep(0, loop=self._loop) + await self._protocol._drain_helper() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + + if limit <= 0: + raise ValueError('Limit cannot be <= 0') + + self._limit = limit + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._buffer = bytearray() + self._eof = False # Whether we're done. + self._waiter = None # A future used by _wait_for_data() + self._exception = None + self._transport = None + self._paused = False + + def __repr__(self): + info = ['StreamReader'] + if self._buffer: + info.append(f'{len(self._buffer)} bytes') + if self._eof: + info.append('eof') + if self._limit != _DEFAULT_LIMIT: + info.append(f'limit={self._limit}') + if self._waiter: + info.append(f'waiter={self._waiter!r}') + if self._exception: + info.append(f'exception={self._exception!r}') + if self._transport: + info.append(f'transport={self._transport!r}') + if self._paused: + info.append('paused') + return '<{}>'.format(' '.join(info)) + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def _wakeup_waiter(self): + """Wakeup read*() functions waiting for data or EOF.""" + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_result(None) + + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and len(self._buffer) <= self._limit: + self._paused = False + self._transport.resume_reading() + + def feed_eof(self): + self._eof = True + self._wakeup_waiter() + + def at_eof(self): + """Return True if the buffer is empty and 'feed_eof' was called.""" + return self._eof and not self._buffer + + def feed_data(self, data): + assert not self._eof, 'feed_data after feed_eof' + + if not data: + return + + self._buffer.extend(data) + self._wakeup_waiter() + + if (self._transport is not None and + not self._paused and + len(self._buffer) > 2 * self._limit): + try: + self._transport.pause_reading() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + + async def _wait_for_data(self, func_name): + """Wait until feed_data() or feed_eof() is called. + + If stream was paused, automatically resume it. + """ + # StreamReader uses a future to link the protocol feed_data() method + # to a read coroutine. Running two read coroutines at the same time + # would have an unexpected behaviour. It would not possible to know + # which coroutine would get the next data. + if self._waiter is not None: + raise RuntimeError( + f'{func_name}() called while another coroutine is ' + f'already waiting for incoming data') + + assert not self._eof, '_wait_for_data after EOF' + + # Waiting for data while paused will make deadlock, so prevent it. + # This is essential for readexactly(n) for case when n > self._limit. + if self._paused: + self._paused = False + self._transport.resume_reading() + + self._waiter = self._loop.create_future() + try: + await self._waiter + finally: + self._waiter = None + + async def readline(self): + """Read chunk of data from the stream until newline (b'\n') is found. + + On success, return chunk that ends with newline. If only partial + line can be read due to EOF, return incomplete line without + terminating newline. When EOF was reached while no bytes read, empty + bytes object is returned. + + If limit is reached, ValueError will be raised. In that case, if + newline was found, complete line including newline will be removed + from internal buffer. Else, internal buffer will be cleared. Limit is + compared against part of the line without newline. + + If stream was paused, this function will automatically resume it if + needed. + """ + sep = b'\n' + seplen = len(sep) + try: + line = await self.readuntil(sep) + except IncompleteReadError as e: + return e.partial + except LimitOverrunError as e: + if self._buffer.startswith(sep, e.consumed): + del self._buffer[:e.consumed + seplen] + else: + self._buffer.clear() + self._maybe_resume_transport() + raise ValueError(e.args[0]) + return line + + async def readuntil(self, separator=b'\n'): + """Read data from the stream until ``separator`` is found. + + On success, the data and separator will be removed from the + internal buffer (consumed). Returned data will include the + separator at the end. + + Configured stream limit is used to check result. Limit sets the + maximal length of data that can be returned, not counting the + separator. + + If an EOF occurs and the complete separator is still not found, + an IncompleteReadError exception will be raised, and the internal + buffer will be reset. The IncompleteReadError.partial attribute + may contain the separator partially. + + If the data cannot be read because of over limit, a + LimitOverrunError exception will be raised, and the data + will be left in the internal buffer, so it can be read again. + """ + seplen = len(separator) + if seplen == 0: + raise ValueError('Separator should be at least one-byte string') + + if self._exception is not None: + raise self._exception + + # Consume whole buffer except last bytes, which length is + # one less than seplen. Let's check corner cases with + # separator='SEPARATOR': + # * we have received almost complete separator (without last + # byte). i.e buffer='some textSEPARATO'. In this case we + # can safely consume len(separator) - 1 bytes. + # * last byte of buffer is first byte of separator, i.e. + # buffer='abcdefghijklmnopqrS'. We may safely consume + # everything except that last byte, but this require to + # analyze bytes of buffer that match partial separator. + # This is slow and/or require FSM. For this case our + # implementation is not optimal, since require rescanning + # of data that is known to not belong to separator. In + # real world, separator will not be so long to notice + # performance problems. Even when reading MIME-encoded + # messages :) + + # `offset` is the number of bytes from the beginning of the buffer + # where there is no occurrence of `separator`. + offset = 0 + + # Loop until we find `separator` in the buffer, exceed the buffer size, + # or an EOF has happened. + while True: + buflen = len(self._buffer) + + # Check if we now have enough data in the buffer for `separator` to + # fit. + if buflen - offset >= seplen: + isep = self._buffer.find(separator, offset) + + if isep != -1: + # `separator` is in the buffer. `isep` will be used later + # to retrieve the data. + break + + # see upper comment for explanation. + offset = buflen + 1 - seplen + if offset > self._limit: + raise LimitOverrunError( + 'Separator is not found, and chunk exceed the limit', + offset) + + # Complete message (with full separator) may be present in buffer + # even when EOF flag is set. This may happen when the last chunk + # adds data which makes separator be found. That's why we check for + # EOF *ater* inspecting the buffer. + if self._eof: + chunk = bytes(self._buffer) + self._buffer.clear() + raise IncompleteReadError(chunk, None) + + # _wait_for_data() will resume reading if stream was paused. + await self._wait_for_data('readuntil') + + if isep > self._limit: + raise LimitOverrunError( + 'Separator is found, but chunk is longer than limit', isep) + + chunk = self._buffer[:isep + seplen] + del self._buffer[:isep + seplen] + self._maybe_resume_transport() + return bytes(chunk) + + async def read(self, n=-1): + """Read up to `n` bytes from the stream. + + If n is not provided, or set to -1, read until EOF and return all read + bytes. If the EOF was received and the internal buffer is empty, return + an empty bytes object. + + If n is zero, return empty bytes object immediately. + + If n is positive, this function try to read `n` bytes, and may return + less or equal bytes than requested, but at least one byte. If EOF was + received before any byte is read, this function returns empty byte + object. + + Returned value is not limited with limit, configured at stream + creation. + + If stream was paused, this function will automatically resume it if + needed. + """ + + if self._exception is not None: + raise self._exception + + if n == 0: + return b'' + + if n < 0: + # This used to just loop creating a new waiter hoping to + # collect everything in self._buffer, but that would + # deadlock if the subprocess sends more than self.limit + # bytes. So just call self.read(self._limit) until EOF. + blocks = [] + while True: + block = await self.read(self._limit) + if not block: + break + blocks.append(block) + return b''.join(blocks) + + if not self._buffer and not self._eof: + await self._wait_for_data('read') + + # This will work right even if buffer is less than n bytes + data = bytes(self._buffer[:n]) + del self._buffer[:n] + + self._maybe_resume_transport() + return data + + async def readexactly(self, n): + """Read exactly `n` bytes. + + Raise an IncompleteReadError if EOF is reached before `n` bytes can be + read. The IncompleteReadError.partial attribute of the exception will + contain the partial read bytes. + + if n is zero, return empty bytes object. + + Returned value is not limited with limit, configured at stream + creation. + + If stream was paused, this function will automatically resume it if + needed. + """ + if n < 0: + raise ValueError('readexactly size can not be less than zero') + + if self._exception is not None: + raise self._exception + + if n == 0: + return b'' + + while len(self._buffer) < n: + if self._eof: + incomplete = bytes(self._buffer) + self._buffer.clear() + raise IncompleteReadError(incomplete, n) + + await self._wait_for_data('readexactly') + + if len(self._buffer) == n: + data = bytes(self._buffer) + self._buffer.clear() + else: + data = bytes(self._buffer[:n]) + del self._buffer[:n] + self._maybe_resume_transport() + return data + + def __aiter__(self): + return self + + async def __anext__(self): + val = await self.readline() + if val == b'': + raise StopAsyncIteration + return val + + +# end legacy stream APIs + + class _BaseStreamProtocol(FlowControlMixin, protocols.Protocol): """Helper class to adapt between Protocol and StreamReader. @@ -645,32 +1133,6 @@ def connection_lost(self, exc): self._stream_wr = None -class _LegacyServerStreamProtocol(_BaseStreamProtocol): - def __init__(self, limit, client_connected_cb, loop=None, - *, _asyncio_internal=False): - super().__init__(loop=loop, _asyncio_internal=_asyncio_internal) - self._client_connected_cb = client_connected_cb - self._limit = limit - - def connection_made(self, transport): - super().connection_made(transport) - stream = Stream(mode=StreamMode.READWRITE, - transport=transport, - protocol=self, - limit=self._limit, - loop=self._loop, - is_server_side=True, - _asyncio_internal=True) - self._stream = stream - res = self._client_connected_cb(self._stream, self._stream) - if coroutines.iscoroutine(res): - self._loop.create_task(res) - - def connection_lost(self, exc): - super().connection_lost(exc) - self._stream = None - - class _ServerStreamProtocol(_BaseStreamProtocol): def __init__(self, server, limit, client_connected_cb, loop=None, *, _asyncio_internal=False): diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 1b91c2d7c155d5..9443b1549b408c 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1029,27 +1029,27 @@ def test_del_stream_before_sock_closing(self): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, loop=self.loop)) - sock = wr.get_extra_info('socket') - self.assertNotEqual(sock.fileno(), -1) + async def test(): - wr.write(b'GET / HTTP/1.0\r\n\r\n') - f = rd.readline() - data = self.loop.run_until_complete(f) - self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + with test_utils.run_test_server() as httpd: + stream = await asyncio.connect(*httpd.address) + sock = stream.get_extra_info('socket') + self.assertNotEqual(sock.fileno(), -1) - # drop refs to reader/writer - del rd - del wr - gc.collect() - # make a chance to close the socket - test_utils.run_briefly(self.loop) + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - self.assertEqual(1, len(messages), messages) - self.assertEqual(sock.fileno(), -1) + # drop refs to reader/writer + del stream + gc.collect() + # make a chance to close the socket + await asyncio.sleep(0) + self.assertEqual(1, len(messages), messages) + self.assertEqual(sock.fileno(), -1) + + self.loop.run_until_complete(test()) self.assertEqual(1, len(messages), messages) self.assertEqual('An open stream object is being garbage ' 'collected; call "stream.close()" explicitly.', @@ -1082,14 +1082,14 @@ def test_del_stream_before_connection_made(self): def test_async_writer_api(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + stream = await asyncio.connect(*httpd.address) - await wr.write(b'GET / HTTP/1.0\r\n\r\n') - data = await rd.readline() + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - data = await rd.read() + data = await stream.read() self.assertTrue(data.endswith(b'\r\n\r\nTest message')) - await wr.close() + await stream.close() messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1099,18 +1099,18 @@ async def inner(httpd): self.assertEqual(messages, []) - def test_async_writer_api(self): + def test_async_writer_api_exception_after_close(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + stream = await asyncio.connect(*httpd.address) - await wr.write(b'GET / HTTP/1.0\r\n\r\n') - data = await rd.readline() + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - data = await rd.read() + data = await stream.read() self.assertTrue(data.endswith(b'\r\n\r\nTest message')) - wr.close() + stream.close() with self.assertRaises(ConnectionResetError): - await wr.write(b'data') + await stream.write(b'data') messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1130,7 +1130,8 @@ def test_eof_feed_when_closing_writer(self): asyncio.open_connection(*httpd.address, loop=self.loop)) - f = wr.close() + wr.close() + f = wr.wait_closed() self.loop.run_until_complete(f) assert rd.at_eof() f = rd.read() From 25d59feb4dd61b38e737bcbf194849c208eed757 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 14:07:39 +0300 Subject: [PATCH 57/82] Fix exception names --- Lib/asyncio/streams.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 6a26c70ef136e7..1790a821d69ce1 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -807,9 +807,9 @@ async def readline(self): seplen = len(sep) try: line = await self.readuntil(sep) - except IncompleteReadError as e: + except exceptions.IncompleteReadError as e: return e.partial - except LimitOverrunError as e: + except exceptions.LimitOverrunError as e: if self._buffer.startswith(sep, e.consumed): del self._buffer[:e.consumed + seplen] else: @@ -884,7 +884,7 @@ async def readuntil(self, separator=b'\n'): # see upper comment for explanation. offset = buflen + 1 - seplen if offset > self._limit: - raise LimitOverrunError( + raise exceptions.LimitOverrunError( 'Separator is not found, and chunk exceed the limit', offset) @@ -895,13 +895,13 @@ async def readuntil(self, separator=b'\n'): if self._eof: chunk = bytes(self._buffer) self._buffer.clear() - raise IncompleteReadError(chunk, None) + raise exception.IncompleteReadError(chunk, None) # _wait_for_data() will resume reading if stream was paused. await self._wait_for_data('readuntil') if isep > self._limit: - raise LimitOverrunError( + raise exceptions.LimitOverrunError( 'Separator is found, but chunk is longer than limit', isep) chunk = self._buffer[:isep + seplen] @@ -987,7 +987,7 @@ async def readexactly(self, n): if self._eof: incomplete = bytes(self._buffer) self._buffer.clear() - raise IncompleteReadError(incomplete, n) + raise exception.IncompleteReadError(incomplete, n) await self._wait_for_data('readexactly') From e07649851e963bd3e8a5547e976e1fc13e721df2 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 15:16:05 +0300 Subject: [PATCH 58/82] Deprecate old streams --- Lib/asyncio/streams.py | 20 ++++ Lib/test/test_asyncio/test_streams.py | 135 +++++++++++++++++++++----- 2 files changed, 131 insertions(+), 24 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 1790a821d69ce1..f0fde245e721f6 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -135,6 +135,11 @@ async def open_connection(host=None, port=None, *, StreamReaderProtocol classes, just copy the code -- there's really nothing special here except some convenience.) """ + warnings.warn("open_connection() is deprecated since Python 3.8 " + "in favor of connect(), and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() reader = StreamReader(limit=limit, loop=loop) @@ -168,6 +173,11 @@ async def start_server(client_connected_cb, host=None, port=None, *, The return value is the same as loop.create_server(), i.e. a Server object which can be used to stop the service. """ + warnings.warn("start_server() is deprecated since Python 3.8 " + "in favor of StreamServer(), and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() @@ -357,6 +367,11 @@ def factory(): async def open_unix_connection(path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `open_connection` but works with UNIX Domain Sockets.""" + warnings.warn("open_unix_connection() is deprecated since Python 3.8 " + "in favor of connect_unix(), and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() reader = StreamReader(limit=limit, loop=loop) @@ -393,6 +408,11 @@ async def connect_unix(path=None, *, async def start_unix_server(client_connected_cb, path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `start_server` but works with UNIX Domain Sockets.""" + warnings.warn("start_unix_server() is deprecated since Python 3.8 " + "in favor of UnixStreamServer(), and scheduled " + "for removal in Python 3.10", + DeprecationWarning, + stacklevel=2) if loop is None: loop = events.get_event_loop() diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 9443b1549b408c..0cad81f5b85871 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -62,13 +62,15 @@ def tearDown(self): @mock.patch('asyncio.streams.events') def test_ctor_global_loop(self, m_events): - stream = asyncio.Stream(mode=asyncio.StreamMode.READ, _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + _asyncio_internal=True) self.assertIs(stream._loop, m_events.get_event_loop.return_value) def _basetest_open_connection(self, open_connection_fut): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - reader, writer = self.loop.run_until_complete(open_connection_fut) + with self.assertWarns(DeprecationWarning): + reader, writer = self.loop.run_until_complete(open_connection_fut) writer.write(b'GET / HTTP/1.0\r\n\r\n') f = reader.readline() data = self.loop.run_until_complete(f) @@ -96,7 +98,9 @@ def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) try: - reader, writer = self.loop.run_until_complete(open_connection_fut) + with self.assertWarns(DeprecationWarning): + reader, writer = self.loop.run_until_complete( + open_connection_fut) finally: asyncio.set_event_loop(None) writer.write(b'GET / HTTP/1.0\r\n\r\n') @@ -132,7 +136,8 @@ def test_open_unix_connection_no_loop_ssl(self): def _basetest_open_connection_error(self, open_connection_fut): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - reader, writer = self.loop.run_until_complete(open_connection_fut) + with self.assertWarns(DeprecationWarning): + reader, writer = self.loop.run_until_complete(open_connection_fut) writer._protocol.connection_lost(ZeroDivisionError()) f = reader.read() with self.assertRaises(ZeroDivisionError): @@ -711,8 +716,9 @@ def stop(self): self.server = None async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + *addr, loop=self.loop) # send a line writer.write(b"hello world!\n") # read it back @@ -726,7 +732,8 @@ async def client(addr): # test the server variant with a coroutine as client handler server = MyServer(self.loop) - addr = server.start() + with self.assertWarns(DeprecationWarning): + addr = server.start() msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() @@ -734,7 +741,8 @@ async def client(addr): # test the server variant with a callback as client handler server = MyServer(self.loop) - addr = server.start_callback() + with self.assertWarns(DeprecationWarning): + addr = server.start_callback() msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() @@ -782,8 +790,9 @@ def stop(self): self.server = None async def client(path): - reader, writer = await asyncio.open_unix_connection( - path, loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_unix_connection( + path, loop=self.loop) # send a line writer.write(b"hello world!\n") # read it back @@ -798,7 +807,8 @@ async def client(path): # test the server variant with a coroutine as client handler with test_utils.unix_socket_path() as path: server = MyServer(self.loop, path) - server.start() + with self.assertWarns(DeprecationWarning): + server.start() msg = self.loop.run_until_complete(asyncio.Task(client(path), loop=self.loop)) server.stop() @@ -807,7 +817,8 @@ async def client(path): # test the server variant with a callback as client handler with test_utils.unix_socket_path() as path: server = MyServer(self.loop, path) - server.start_callback() + with self.assertWarns(DeprecationWarning): + server.start_callback() msg = self.loop.run_until_complete(asyncio.Task(client(path), loop=self.loop)) server.stop() @@ -877,7 +888,7 @@ def test_streamreaderprotocol_constructor(self): protocol = _StreamProtocol(stream, _asyncio_internal=True) self.assertIs(protocol._loop, self.loop) - def test_drain_raises(self): + def test_drain_raises_deprecated(self): # See http://bugs.python.org/issue25441 # This test should not use asyncio for the mock server; the @@ -898,8 +909,9 @@ def server(): clt.close() async def client(host, port): - reader, writer = await asyncio.open_connection( - host, port, loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + host, port, loop=self.loop) while True: writer.write(b"foo\n") @@ -921,6 +933,49 @@ async def client(host, port): thread.join() self.assertEqual([], messages) + def test_drain_raises(self): + # See http://bugs.python.org/issue25441 + + # This test should not use asyncio for the mock server; the + # whole point of the test is to test for a bug in drain() + # where it never gives up the event loop but the socket is + # closed on the server side. + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + q = queue.Queue() + + def server(): + # Runs in a separate thread. + with socket.create_server(('localhost', 0)) as sock: + addr = sock.getsockname() + q.put(addr) + clt, _ = sock.accept() + clt.close() + + async def client(host, port): + stream = await asyncio.connect(host, port) + + while True: + stream.write(b"foo\n") + await stream.drain() + + # Start the server thread and wait for it to be listening. + thread = threading.Thread(target=server) + thread.setDaemon(True) + thread.start() + addr = q.get() + + # Should not be stuck in an infinite loop. + with self.assertRaises((ConnectionResetError, ConnectionAbortedError, + BrokenPipeError)): + self.loop.run_until_complete(client(*addr)) + + # Clean up the thread. (Only on success; on failure, it may + # be stuck in accept().) + thread.join() + self.assertEqual([], messages) + def test___repr__(self): stream = asyncio.Stream(mode=asyncio.StreamMode.READ, loop=self.loop, @@ -996,10 +1051,11 @@ def test_LimitOverrunError_pickleable(self): self.assertEqual(str(e), str(e2)) self.assertEqual(e.consumed, e2.consumed) - def test_wait_closed_on_close(self): + def test_wait_closed_on_close_deprecated(self): with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, loop=self.loop)) + with self.assertWarns(DeprecationWarning): + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -1013,10 +1069,28 @@ def test_wait_closed_on_close(self): self.assertTrue(wr.is_closing()) self.loop.run_until_complete(wr.wait_closed()) - def test_wait_closed_on_close_with_unread_data(self): + def test_wait_closed_on_close(self): with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, loop=self.loop)) + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + + stream.write(b'GET / HTTP/1.0\r\n\r\n') + f = stream.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = stream.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertFalse(stream.is_closing()) + stream.close() + self.assertTrue(stream.is_closing()) + self.loop.run_until_complete(stream.wait_closed()) + + def test_wait_closed_on_close_with_unread_data_deprecated(self): + with test_utils.run_test_server() as httpd: + with self.assertWarns(DeprecationWarning): + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -1025,6 +1099,18 @@ def test_wait_closed_on_close_with_unread_data(self): wr.close() self.loop.run_until_complete(wr.wait_closed()) + def test_wait_closed_on_close_with_unread_data(self): + with test_utils.run_test_server() as httpd: + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + + stream.write(b'GET / HTTP/1.0\r\n\r\n') + f = stream.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + stream.close() + self.loop.run_until_complete(stream.wait_closed()) + def test_del_stream_before_sock_closing(self): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1126,9 +1212,10 @@ def test_eof_feed_when_closing_writer(self): self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, - loop=self.loop)) + with self.assertWarns(DeprecationWarning): + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, + loop=self.loop)) wr.close() f = wr.wait_closed() From ef44960c7478c4e868dd1e3168647390af54e67b Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 15:46:11 +0300 Subject: [PATCH 59/82] Fix tests --- Lib/asyncio/windows_events.py | 2 +- Lib/test/test_asyncio/test_streams.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py index 29750f18d80c46..b8d3a461a8a30d 100644 --- a/Lib/asyncio/windows_events.py +++ b/Lib/asyncio/windows_events.py @@ -605,7 +605,7 @@ async def connect_pipe(self, address): # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) - await tasks.sleep(delay, loop=self._loop) + await tasks.sleep(delay) return windows_utils.PipeHandle(handle) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 096b5da0a636ba..ff7cb3743dd5d4 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -891,7 +891,7 @@ def test_drain_raises(self): def server(): # Runs in a separate thread. - with socket.create_server(('localhost', 0)) as sock: + with socket.create_server(('127.0.0.1', 0)) as sock: addr = sock.getsockname() q.put(addr) clt, _ = sock.accept() @@ -1512,7 +1512,7 @@ async def do_connect(host, port): await stream.close() async def test(): - async with asyncio.StreamServer(serve_callback, 'localhost', 0) as srv: + async with asyncio.StreamServer(serve_callback, '127.0.0.1', 0) as srv: await srv.start_serving() await do_connect(*srv.addresses()[0]) From 504ea78b0d16f0ff2eb94bfbf9a31e753dfb2555 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 15:58:10 +0300 Subject: [PATCH 60/82] Process deprecated StreamReader and StreamWriter in test_all --- Lib/test/test___all__.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/Lib/test/test___all__.py b/Lib/test/test___all__.py index 116abca6c8a629..c077881511b8ce 100644 --- a/Lib/test/test___all__.py +++ b/Lib/test/test___all__.py @@ -30,23 +30,27 @@ def check_all(self, modname): raise NoAll(modname) names = {} with self.subTest(module=modname): - try: - exec("from %s import *" % modname, names) - except Exception as e: - # Include the module name in the exception string - self.fail("__all__ failure in {}: {}: {}".format( - modname, e.__class__.__name__, e)) - if "__builtins__" in names: - del names["__builtins__"] - if '__annotations__' in names: - del names['__annotations__'] - if "__warningregistry__" in names: - del names["__warningregistry__"] - keys = set(names) - all_list = sys.modules[modname].__all__ - all_set = set(all_list) - self.assertCountEqual(all_set, all_list, "in module {}".format(modname)) - self.assertEqual(keys, all_set, "in module {}".format(modname)) + with support.check_warnings( + ("", DeprecationWarning), + ("", ResourceWarning), + quiet=True): + try: + exec("from %s import *" % modname, names) + except Exception as e: + # Include the module name in the exception string + self.fail("__all__ failure in {}: {}: {}".format( + modname, e.__class__.__name__, e)) + if "__builtins__" in names: + del names["__builtins__"] + if '__annotations__' in names: + del names['__annotations__'] + if "__warningregistry__" in names: + del names["__warningregistry__"] + keys = set(names) + all_list = sys.modules[modname].__all__ + all_set = set(all_list) + self.assertCountEqual(all_set, all_list, "in module {}".format(modname)) + self.assertEqual(keys, all_set, "in module {}".format(modname)) def walk_modules(self, basedir, modpath): for fn in sorted(os.listdir(basedir)): From c1cc2427f8d04dfc38a71d8103f87c19fa5eb75d Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 16:02:24 +0300 Subject: [PATCH 61/82] Deprecate StreamReaderProtocol import --- Lib/asyncio/__init__.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index 00d8bd1ee3c68b..a6a29dbfecd507 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -46,24 +46,38 @@ __all__ += unix_events.__all__ -__all__ += ('StreamReader', 'StreamWriter') # deprecated +__all__ += ('StreamReader', 'StreamWriter', 'StreamReaderProtocol') # deprecated def __getattr__(name): - global StreamReader, StreamWriter + global StreamReader, StreamWriter, StreamReaderProtocol if name == 'StreamReader': - warnings.warn("StreamReader is deprecated, use asyncio.Stream instead", + warnings.warn("StreamReader is deprecated since Python 3.8 " + "in favor of Stream, and scheduled for removal " + "in Python 3.10", DeprecationWarning, stacklevel=2) from .streams import StreamReader as sr StreamReader = sr return StreamReader if name == 'StreamWriter': - warnings.warn("StreamWriter is deprecated, use asyncio.Stream instead", + warnings.warn("StreamWriter is deprecated since Python 3.8 " + "in favor of Stream, and scheduled for removal " + "in Python 3.10", DeprecationWarning, stacklevel=2) from .streams import StreamWriter as sw StreamWriter = sw return StreamWriter + if name == 'StreamReaderProtocol': + warnings.warn("Using asyncio internal class StreamReaderProtocol " + "is deprecated since Python 3.8 " + " and scheduled for removal " + "in Python 3.10", + DeprecationWarning, + stacklevel=2) + from .streams import StreamReaderProtocol as srp + StreamReaderProtocol = srp + return StreamReaderProtocol raise AttributeError(f"module {__name__} has no attribute {name}") From fd2a6eb71a0219f50469ab3cd4af16476459ee0d Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 16:05:50 +0300 Subject: [PATCH 62/82] Fix module name --- Lib/asyncio/streams.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index f0fde245e721f6..83282cc9b8136e 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -915,7 +915,7 @@ async def readuntil(self, separator=b'\n'): if self._eof: chunk = bytes(self._buffer) self._buffer.clear() - raise exception.IncompleteReadError(chunk, None) + raise exceptions.IncompleteReadError(chunk, None) # _wait_for_data() will resume reading if stream was paused. await self._wait_for_data('readuntil') @@ -1007,7 +1007,7 @@ async def readexactly(self, n): if self._eof: incomplete = bytes(self._buffer) self._buffer.clear() - raise exception.IncompleteReadError(incomplete, n) + raise exceptions.IncompleteReadError(incomplete, n) await self._wait_for_data('readexactly') From b6fbc9f5d959827254f992a722eaddde01bb6e70 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 16:29:28 +0300 Subject: [PATCH 63/82] Suppress deprecation warnings --- Lib/test/test_asyncio/test_base_events.py | 5 +-- Lib/test/test_asyncio/test_buffered_proto.py | 7 ++-- Lib/test/test_asyncio/test_server.py | 10 +++--- Lib/test/test_asyncio/test_sslproto.py | 37 +++++++++++--------- 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index f068fc781f5d7c..68057667311fb7 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -1154,8 +1154,9 @@ def test_create_server_stream_bittype(self): @unittest.skipUnless(hasattr(socket, 'AF_INET6'), 'no IPv6 support') def test_create_server_ipv6(self): async def main(): - srv = await asyncio.start_server( - lambda: None, '::1', 0, loop=self.loop) + with self.assertWarns(DeprecationWarning): + srv = await asyncio.start_server( + lambda: None, '::1', 0, loop=self.loop) try: self.assertGreater(len(srv.sockets), 0) finally: diff --git a/Lib/test/test_asyncio/test_buffered_proto.py b/Lib/test/test_asyncio/test_buffered_proto.py index f24e363ebfcfa3..b1531fb9343f5e 100644 --- a/Lib/test/test_asyncio/test_buffered_proto.py +++ b/Lib/test/test_asyncio/test_buffered_proto.py @@ -58,9 +58,10 @@ async def on_server_client(reader, writer): writer.close() await writer.wait_closed() - srv = self.loop.run_until_complete( - asyncio.start_server( - on_server_client, '127.0.0.1', 0)) + with self.assertWarns(DeprecationWarning): + srv = self.loop.run_until_complete( + asyncio.start_server( + on_server_client, '127.0.0.1', 0)) addr = srv.sockets[0].getsockname() self.loop.run_until_complete( diff --git a/Lib/test/test_asyncio/test_server.py b/Lib/test/test_asyncio/test_server.py index ab7f3debbc152e..235bfdc6a952f3 100644 --- a/Lib/test/test_asyncio/test_server.py +++ b/Lib/test/test_asyncio/test_server.py @@ -46,8 +46,9 @@ async def main(srv): async with srv: await srv.serve_forever() - srv = self.loop.run_until_complete(asyncio.start_server( - serve, support.HOSTv4, 0, loop=self.loop, start_serving=False)) + with self.assertWarns(DeprecationWarning): + srv = self.loop.run_until_complete(asyncio.start_server( + serve, support.HOSTv4, 0, loop=self.loop, start_serving=False)) self.assertFalse(srv.is_serving()) @@ -102,8 +103,9 @@ async def main(srv): await srv.serve_forever() with test_utils.unix_socket_path() as addr: - srv = self.loop.run_until_complete(asyncio.start_unix_server( - serve, addr, loop=self.loop, start_serving=False)) + with self.assertWarns(DeprecationWarning): + srv = self.loop.run_until_complete(asyncio.start_unix_server( + serve, addr, loop=self.loop, start_serving=False)) main_task = self.loop.create_task(main(srv)) diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index 079b25585566b1..4215abf5d8630b 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -649,12 +649,13 @@ def server(sock): sock.close() async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - loop=self.loop, - ssl_handshake_timeout=1.0) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop, + ssl_handshake_timeout=1.0) with self.tcp_server(server, max_clients=1, @@ -688,12 +689,13 @@ def server(sock): sock.close() async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - loop=self.loop, - ssl_handshake_timeout=1.0) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop, + ssl_handshake_timeout=1.0) with self.tcp_server(server, max_clients=1, @@ -724,11 +726,12 @@ def server(sock): sock.close() async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + *addr, + ssl=client_sslctx, + server_hostname='', + loop=self.loop) self.assertEqual(await reader.readline(), b'A\n') writer.write(b'B') From c190c7260a91c2aeeac0f7961cdf238795988455 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 23 May 2019 16:34:50 +0300 Subject: [PATCH 64/82] Fix another warning --- Lib/test/test_asyncio/test_windows_events.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py index c7d7bf433d8a8d..13aef7cf1f776b 100644 --- a/Lib/test/test_asyncio/test_windows_events.py +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -119,6 +119,7 @@ async def _test_pipe(self): response = await r.readline() self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) w.close() + await r.close() server.close() From d80cb9c67f2ba608b0aca56d6208a57aa0c54a68 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 24 May 2019 12:22:58 +0300 Subject: [PATCH 65/82] listeners -> sockets --- Lib/asyncio/streams.py | 61 +++------------------------ Lib/test/test_asyncio/test_streams.py | 22 +++++----- 2 files changed, 16 insertions(+), 67 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 83282cc9b8136e..a7c14277a0c35c 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,5 +1,5 @@ __all__ = ( - 'Stream', 'StreamMode', 'Listener', + 'Stream', 'StreamMode', 'open_connection', 'start_server', 'connect', 'StreamServer') @@ -40,57 +40,6 @@ def _check_write(self): if not self & self.WRITE: raise RuntimeError("The stream is read-only") -class Listener: - # The class represents listening sockets served by - # StreamServer and UnixStreamServer - # The exposed API is IPv4/IPv6/UNIX address itself - # and socket fileno - # The socket object is not exposed intentionally - # to prevent wild access from high-level API - # sock = socket.fromfd(fileno, family, type, proto) - # still can be used if the hacking is really necessary - - __slots__ = ('_family', '_type', '_proto', '_addr', '_fileno') - - def __init__(self, sock, _asyncio_internal=False): - if not _asyncio_internal: - raise TypeError(f"{self.__class__} should be instaniated " - "by asyncio internals only") - self._family = sock.family - self._type = sock.type - self._proto = sock.proto - self._addr = sock.getsockname() - self._fileno = sock.fileno() - - @property - def family(self): - return self._family - - @property - def type(self): - return self._type - - @property - def proto(self): - return self._proto - - @property - def addr(self): - return self._addr - - def fileno(self): - return self._fileno - - def __repr__(self): - ret = [f"<{self.__class__.__name__}"] - ret.append(f"family={self.family}") - ret.append(f"type={self.type}") - ret.append(f"proto={self.proto}") - ret.append(f"addr={self.addr}") - ret.append(f"fileno={self.fileno()}") - ret.append(">") - return " ".join(ret) - async def connect(host=None, port=None, *, limit=_DEFAULT_LIMIT, @@ -225,12 +174,12 @@ async def bind(self): def is_bound(self): return self._low_server is not None - def listeners(self): + @property + def sockets(self): # multiple value for socket bound to both IPv4 and IPv6 families if self._low_server is None: - return tuple() - return tuple(Listener(sock, _asyncio_internal=True) - for sock in self._low_server.sockets) + return [] + return self._low_server.sockets def is_serving(self): if self._low_server is None: diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 0746f254b6c525..6cef34230c19df 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1313,7 +1313,7 @@ async def handle_client(stream): await stream.close() async def client(srv): - addr = srv.listeners()[0].addr + addr = srv.sockets[0].getsockname() stream = await asyncio.connect(*addr) # send a line await stream.write(b"hello world!\n") @@ -1345,7 +1345,7 @@ async def handle_client(stream): await stream.close() async def client(srv): - addr = srv.listeners()[0].addr + addr = srv.sockets[0].getsockname() stream = await asyncio.connect_unix(addr) # send a line await stream.write(b"hello world!\n") @@ -1387,13 +1387,13 @@ async def handle_client(stream): async def test(): srv = asyncio.StreamServer(handle_client, '127.0.0.1', 0) self.assertFalse(srv.is_bound()) - self.assertEqual(0, len(srv.listeners())) + self.assertEqual(0, len(srv.sockets)) await srv.bind() self.assertTrue(srv.is_bound()) - self.assertEqual(1, len(srv.listeners())) + self.assertEqual(1, len(srv.sockets)) await srv.close() self.assertFalse(srv.is_bound()) - self.assertEqual(0, len(srv.listeners())) + self.assertEqual(0, len(srv.sockets)) messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1407,7 +1407,7 @@ async def handle_client(stream): async def test(): async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv: self.assertTrue(srv.is_bound()) - self.assertEqual(1, len(srv.listeners())) + self.assertEqual(1, len(srv.sockets)) messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1442,7 +1442,7 @@ async def handle_client(stream): server_stream_aborted = True async def client(srv): - addr = srv.listeners()[0].addr + addr = srv.sockets[0].getsockname() stream = await asyncio.connect(*addr) fut.set_result(None) self.assertEqual(b'', await stream.readline()) @@ -1474,7 +1474,7 @@ async def handle_client(stream): server_stream_aborted = True async def client(srv): - addr = srv.listeners()[0].addr + addr = srv.sockets[0].getsockname() stream = await asyncio.connect(*addr) fut.set_result(None) self.assertEqual(b'', await stream.readline()) @@ -1504,7 +1504,7 @@ async def handle_client(stream): await asyncio.sleep(0.01) async def client(srv): - addr = srv.listeners()[0].addr + addr = srv.sockets[0].getsockname() stream = await asyncio.connect(*addr) fut1.set_result(None) await fut2 @@ -1541,7 +1541,7 @@ async def handle_client(stream): await asyncio.sleep(0.01) async def client(srv): - addr = srv.listeners()[0].addr + addr = srv.sockets[0].getsockname() stream = await asyncio.connect(*addr) fut1.set_result(None) await fut2 @@ -1602,7 +1602,7 @@ async def do_connect(host, port): async def test(): async with asyncio.StreamServer(serve_callback, '127.0.0.1', 0) as srv: await srv.start_serving() - await do_connect(*srv.listeners()[0].addr) + await do_connect(*srv.sockets[0].getsockname()) self.loop.run_until_complete(test()) From 1ae7d8dda62975c2820bbc4c9f9386efe5e1956b Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 24 May 2019 12:57:14 +0300 Subject: [PATCH 66/82] Add repr --- Lib/asyncio/streams.py | 13 +++++++--- Lib/test/test_asyncio/test_streams.py | 35 +++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index a7c14277a0c35c..9dec696386add3 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -141,8 +141,6 @@ def factory(): class _BaseStreamServer: - # TODO: add __repr__ - # Design notes. # StreamServer and UnixStreamServer are exposed as FINAL classes, # not function factories. @@ -152,6 +150,8 @@ class _BaseStreamServer: # The class doesn't provide API for enumerating connected streams # It can be a subject for improvements in Python 3.9 + _low_server = None + def __init__(self, client_connected_cb, /, limit=_DEFAULT_LIMIT, @@ -162,7 +162,6 @@ def __init__(self, client_connected_cb, self._client_connected_cb = client_connected_cb self._limit = limit self._loop = events.get_running_loop() - self._low_server = None self._streams = {} self._shutdown_timeout = shutdown_timeout @@ -256,6 +255,14 @@ async def _shutdown_active_tasks(self, active_tasks): "stream_server": self }) + def __repr__(self): + ret = [f'{self.__class__.__name__}'] + if self.is_serving(): + ret.append('serving') + if self.sockets: + ret.append(f'sockets={self.sockets!r}') + return '<' + ' '.join(ret) + '>' + def __del__(self, _warn=warnings.warn): if self._low_server is not None: _warn(f"unclosed stream server {self!r}", diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 6cef34230c19df..12ef1f5ec266ac 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1622,6 +1622,41 @@ def test_connect_start_tls(self): sslcontext=test_utils.dummy_ssl_context())) self._basetest_connect(stream) + def test_repr_unbound(self): + async def serve(stream): + pass + + async def test(): + srv = asyncio.StreamServer(serve) + self.assertEqual('', repr(srv)) + await srv.close() + + self.loop.run_until_complete(test()) + + def test_repr_bound(self): + async def serve(stream): + pass + + async def test(): + srv = asyncio.StreamServer(serve, '127.0.0.1', 0) + await srv.bind() + self.assertRegex(repr(srv), r'') + await srv.close() + + self.loop.run_until_complete(test()) + + def test_repr_serving(self): + async def serve(stream): + pass + + async def test(): + srv = asyncio.StreamServer(serve, '127.0.0.1', 0) + await srv.start_serving() + self.assertRegex(repr(srv), r'') + await srv.close() + + self.loop.run_until_complete(test()) + if __name__ == '__main__': unittest.main() From 9d34da8ecf2aa3e509a7b1618299f59f4abf5052 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 24 May 2019 12:59:28 +0300 Subject: [PATCH 67/82] Fix comment --- Lib/asyncio/streams.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 9dec696386add3..11b142d1fb0d40 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -165,6 +165,11 @@ def __init__(self, client_connected_cb, self._streams = {} self._shutdown_timeout = shutdown_timeout + def __init_subclass__(cls): + if not cls.__module__.startswith('asyncio.'): + raise TypeError(f"asyncio.{cls.__name__} " + "class cannot be inherited from") + async def bind(self): if self._low_server is not None: return @@ -224,11 +229,6 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, exc_tb): await self.close() - def __init_subclass__(cls): - if not cls.__module__.startswith('asyncio.'): - raise TypeError(f"asyncio.{cls.__name__} " - "class cannot be inherited from") - def _attach(self, stream, task): self._streams[stream] = task From d0c9adadc322c9d4d2f4d67c083d71d6e630bdbe Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 24 May 2019 13:01:24 +0300 Subject: [PATCH 68/82] Fix comment --- Lib/asyncio/streams.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 11b142d1fb0d40..a91812e12c8224 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1146,8 +1146,8 @@ def _swallow_unhandled_exception(task): # if stream.write() was used without await and # stream.drain() was paused and resumed with an exception - # TODO: add if not task.cancelled() check!!!! - task.exception() + if not task.cancelled(): + task.exception() class Stream: From af5f27e789f9960bb01612890e8b14b09ca61a4f Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 24 May 2019 13:22:07 +0300 Subject: [PATCH 69/82] Implement async with stream: ... --- Lib/asyncio/streams.py | 10 ++++------ Lib/test/test_asyncio/test_streams.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index a91812e12c8224..db3f9e12e347aa 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1160,8 +1160,6 @@ class Stream: directly. """ - # TODO: add __aenter__ / __aexit__ to close stream - _source_traceback = None def __init__(self, mode, *, @@ -1648,8 +1646,8 @@ async def __anext__(self): raise StopAsyncIteration return val - # async def __aenter__(self): - # return self + async def __aenter__(self): + return self - # async def __aexit__(self, exc_type, exc_val, exc_tb): - # await self.close() + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 12ef1f5ec266ac..6f1be4f6506c68 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1304,6 +1304,20 @@ def test_connect_unix(self): asyncio.connect_unix(httpd.address)) self._basetest_connect(stream) + def test_stream_async_context_manager(self): + async def test(httpd): + stream = await asyncio.connect(*httpd.address) + async with stream: + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await stream.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertTrue(stream.is_closing()) + + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete(test(httpd)) + def test_stream_server(self): async def handle_client(stream): From a734f06078ab997912d469c6088e39c4b6dbf89f Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 24 May 2019 13:24:17 +0300 Subject: [PATCH 70/82] _low_server -> _server_impl --- Lib/asyncio/streams.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index db3f9e12e347aa..29c7a92d40105d 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -150,7 +150,7 @@ class _BaseStreamServer: # The class doesn't provide API for enumerating connected streams # It can be a subject for improvements in Python 3.9 - _low_server = None + _server_impl = None def __init__(self, client_connected_cb, /, @@ -171,55 +171,55 @@ def __init_subclass__(cls): "class cannot be inherited from") async def bind(self): - if self._low_server is not None: + if self._server_impl is not None: return - self._low_server = await self._bind() + self._server_impl = await self._bind() def is_bound(self): - return self._low_server is not None + return self._server_impl is not None @property def sockets(self): # multiple value for socket bound to both IPv4 and IPv6 families - if self._low_server is None: + if self._server_impl is None: return [] - return self._low_server.sockets + return self._server_impl.sockets def is_serving(self): - if self._low_server is None: + if self._server_impl is None: return False - return self._low_server.is_serving() + return self._server_impl.is_serving() async def start_serving(self): await self.bind() - await self._low_server.start_serving() + await self._server_impl.start_serving() async def serve_forever(self): await self.start_serving() - await self._low_server.serve_forever() + await self._server_impl.serve_forever() async def close(self): - if self._low_server is None: + if self._server_impl is None: return - self._low_server.close() + self._server_impl.close() streams = list(self._streams.keys()) active_tasks = list(self._streams.values()) if streams: await tasks.wait([stream.close() for stream in streams]) - await self._low_server.wait_closed() - self._low_server = None + await self._server_impl.wait_closed() + self._server_impl = None await self._shutdown_active_tasks(active_tasks) async def abort(self): - if self._low_server is None: + if self._server_impl is None: return - self._low_server.close() + self._server_impl.close() streams = list(self._streams.keys()) active_tasks = list(self._streams.values()) if streams: await tasks.wait([stream.abort() for stream in streams]) - await self._low_server.wait_closed() - self._low_server = None + await self._server_impl.wait_closed() + self._server_impl = None await self._shutdown_active_tasks(active_tasks) async def __aenter__(self): @@ -264,10 +264,10 @@ def __repr__(self): return '<' + ' '.join(ret) + '>' def __del__(self, _warn=warnings.warn): - if self._low_server is not None: + if self._server_impl is not None: _warn(f"unclosed stream server {self!r}", ResourceWarning, source=self) - self._low_server.close() + self._server_impl.close() class StreamServer(_BaseStreamServer): From e01b3945ce96aa1d2589201220647607f5df028a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 24 May 2019 14:19:01 +0300 Subject: [PATCH 71/82] Support async with asyncio.connect(): ... --- Lib/asyncio/streams.py | 82 +++++++++++++++++++++++---- Lib/test/test_asyncio/test_streams.py | 27 +++++++++ 2 files changed, 97 insertions(+), 12 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 29c7a92d40105d..2e47726b14abf9 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -41,13 +41,53 @@ def _check_write(self): raise RuntimeError("The stream is read-only") -async def connect(host=None, port=None, *, - limit=_DEFAULT_LIMIT, - ssl=None, family=0, proto=0, - flags=0, sock=None, local_addr=None, - server_hostname=None, - ssl_handshake_timeout=None, - happy_eyeballs_delay=None, interleave=None): +class _ContextManagerHelper: + __slots__ = ('_awaitable', '_result') + + def __init__(self, awaitable): + self._awaitable = awaitable + self._result = None + + def __await__(self): + return self._awaitable.__await__() + + async def __aenter__(self): + ret = await self._awaitable + result = await ret.__aenter__() + self._result = result + return result + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self._result.__aexit__(exc_type, exc_val, exc_tb) + + +def connect(host=None, port=None, *, + limit=_DEFAULT_LIMIT, + ssl=None, family=0, proto=0, + flags=0, sock=None, local_addr=None, + server_hostname=None, + ssl_handshake_timeout=None, + happy_eyeballs_delay=None, interleave=None): + # Design note: + # Don't use decorator approach but exilicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect(host, port, limit, + ssl, family, proto, + flags, sock, local_addr, + server_hostname, + ssl_handshake_timeout, + happy_eyeballs_delay, + interleave)) + + +async def _connect(host, port, + limit, + ssl, family, proto, + flags, sock, local_addr, + server_hostname, + ssl_handshake_timeout, + happy_eyeballs_delay, interleave): loop = events.get_running_loop() stream = Stream(mode=StreamMode.READWRITE, limit=limit, @@ -338,11 +378,29 @@ async def open_unix_connection(path=None, *, writer = StreamWriter(transport, protocol, reader, loop) return reader, writer - async def connect_unix(path=None, *, - limit=_DEFAULT_LIMIT, - ssl=None, sock=None, - server_hostname=None, - ssl_handshake_timeout=None): + + def connect_unix(path=None, *, + limit=_DEFAULT_LIMIT, + ssl=None, sock=None, + server_hostname=None, + ssl_handshake_timeout=None): + """Similar to `connect()` but works with UNIX Domain Sockets.""" + # Design note: + # Don't use decorator approach but exilicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect_unix(path, + limit, + ssl, sock, + server_hostname, + ssl_handshake_timeout)) + + + async def _connect_unix(path, + limit, + ssl, sock, + server_hostname, + ssl_handshake_timeout): """Similar to `connect()` but works with UNIX Domain Sockets.""" loop = events.get_running_loop() stream = Stream(mode=StreamMode.READWRITE, diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 6f1be4f6506c68..d0b7fef8ea52cb 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1318,6 +1318,33 @@ async def test(httpd): with test_utils.run_test_server() as httpd: self.loop.run_until_complete(test(httpd)) + def test_connect_async_context_manager(self): + async def test(httpd): + async with asyncio.connect(*httpd.address) as stream: + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await stream.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertTrue(stream.is_closing()) + + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete(test(httpd)) + + @support.skip_unless_bind_unix_socket + def test_connect_unix_async_context_manager(self): + async def test(httpd): + async with asyncio.connect_unix(httpd.address) as stream: + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await stream.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertTrue(stream.is_closing()) + + with test_utils.run_test_unix_server() as httpd: + self.loop.run_until_complete(test(httpd)) + def test_stream_server(self): async def handle_client(stream): From 0ef2195d4fd0f7bdd8d402a86d83906d632814e1 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 24 May 2019 14:44:36 +0300 Subject: [PATCH 72/82] Use _OptionalAwait --- Lib/asyncio/streams.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 2e47726b14abf9..66d6239ab312f3 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1199,13 +1199,20 @@ def connection_lost(self, exc): self._stream = None -def _swallow_unhandled_exception(task): - # Do a trick to suppress unhandled exception - # if stream.write() was used without await and - # stream.drain() was paused and resumed with an exception +class _OptionalAwait: + # The class doesn't create a coroutine + # if not awaited + # It prevents "coroutine is never awaited" message - if not task.cancelled(): - task.exception() + __slot___ = ('_method', '_args', '_kwargs') + + def __init__(self, method, *args, **kwargs): + self._method = method + self._args = args + self._kwargs = kwargs + + def __await__(self): + return self._method(*self._args, *self._kwargs).__await__() class Stream: @@ -1318,9 +1325,7 @@ def _fast_drain(self): # fast path, the stream is not paused # no need to wait for resume signal return self._complete_fut - ret = self._loop.create_task(self.drain()) - ret.add_done_callback(_swallow_unhandled_exception) - return ret + return _OptionalAwait(self.drain) def write_eof(self): self._mode._check_write() From 75804f622b8791c5f8077033e3ba356b05dd6fa5 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sun, 26 May 2019 00:56:39 +0300 Subject: [PATCH 73/82] Imprement connect_read_pipe and connect_write_pipe stream factories --- Lib/asyncio/streams.py | 46 +++++++++++++++++++++++-- Lib/test/test_asyncio/test_streams.py | 49 ++++++++++++++++++++++++--- 2 files changed, 88 insertions(+), 7 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 66d6239ab312f3..01a5abe888c67d 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,7 +1,7 @@ __all__ = ( 'Stream', 'StreamMode', 'open_connection', 'start_server', - 'connect', + 'connect', 'connect_read_pipe', 'connect_write_pipe', 'StreamServer') import enum @@ -105,6 +105,48 @@ async def _connect(host, port, return stream +def connect_read_pipe(pipe, *, limit=_DEFAULT_LIMIT): + # Design note: + # Don't use decorator approach but exilicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect_read_pipe(pipe, limit)) + + +async def _connect_read_pipe(pipe, limit): + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.READ, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.connect_read_pipe( + lambda: _StreamProtocol(stream, loop=loop, + _asyncio_internal=True), + pipe) + return stream + + +def connect_write_pipe(pipe, *, limit=_DEFAULT_LIMIT): + # Design note: + # Don't use decorator approach but exilicit non-async + # function to fail fast and explicitly + # if passed arguments don't match the function signature + return _ContextManagerHelper(_connect_write_pipe(pipe, limit)) + + +async def _connect_write_pipe(pipe, limit): + loop = events.get_running_loop() + stream = Stream(mode=StreamMode.WRITE, + limit=limit, + loop=loop, + _asyncio_internal=True) + await loop.connect_write_pipe( + lambda: _StreamProtocol(stream, loop=loop, + _asyncio_internal=True), + pipe) + return stream + + async def open_connection(host=None, port=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. @@ -1422,13 +1464,11 @@ def _maybe_resume_transport(self): self._transport.resume_reading() def feed_eof(self): - self._mode._check_read() self._eof = True self._wakeup_waiter() def at_eof(self): """Return True if the buffer is empty and 'feed_eof' was called.""" - self._mode._check_read() return self._eof and not self._buffer def feed_data(self, data): diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index d0b7fef8ea52cb..d137cdbb27699d 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -2,6 +2,7 @@ import contextlib import gc +import io import os import queue import pickle @@ -1254,10 +1255,6 @@ def test_stream_writer_forbidden_ops(self): async def inner(): stream = asyncio.Stream(mode=asyncio.StreamMode.WRITE, _asyncio_internal=True) - with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): - stream.feed_eof() - with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): - stream.at_eof() with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): stream.feed_data(b'data') with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): @@ -1699,5 +1696,49 @@ async def test(): self.loop.run_until_complete(test()) + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + async def test(): + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + async with asyncio.connect_read_pipe(pipeobj) as stream: + self.assertEqual(stream.mode, asyncio.StreamMode.READ) + + os.write(wpipe, b'1') + data = await stream.readexactly(1) + self.assertEqual(data, b'1') + + os.write(wpipe, b'2345') + data = await stream.readexactly(4) + self.assertEqual(data, b'2345') + os.close(wpipe) + + self.loop.run_until_complete(test()) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + async def test(): + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + async with asyncio.connect_write_pipe(pipeobj) as stream: + self.assertEqual(stream.mode, asyncio.StreamMode.WRITE) + + await stream.write(b'1') + data = os.read(rpipe, 1024) + self.assertEqual(data, b'1') + + await stream.write(b'2345') + data = os.read(rpipe, 1024) + self.assertEqual(data, b'2345') + + os.close(rpipe) + + self.loop.run_until_complete(test()) + + if __name__ == '__main__': unittest.main() From 476de4e3c723c30f6c6da62b3fec88fd6cfb9f6b Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 27 May 2019 15:31:06 +0300 Subject: [PATCH 74/82] Update Lib/asyncio/streams.py Co-Authored-By: Yury Selivanov --- Lib/asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 01a5abe888c67d..6ab98f7189a8f2 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -128,7 +128,7 @@ async def _connect_read_pipe(pipe, limit): def connect_write_pipe(pipe, *, limit=_DEFAULT_LIMIT): # Design note: - # Don't use decorator approach but exilicit non-async + # Don't use decorator approach but explicit non-async # function to fail fast and explicitly # if passed arguments don't match the function signature return _ContextManagerHelper(_connect_write_pipe(pipe, limit)) From 2d874c7a236e8143be1dbba9a830675a6bf95339 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 27 May 2019 15:31:24 +0300 Subject: [PATCH 75/82] Update Lib/asyncio/streams.py Co-Authored-By: Yury Selivanov --- Lib/asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 6ab98f7189a8f2..36f244868f1e06 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -107,7 +107,7 @@ async def _connect(host, port, def connect_read_pipe(pipe, *, limit=_DEFAULT_LIMIT): # Design note: - # Don't use decorator approach but exilicit non-async + # Don't use decorator approach but explicit non-async # function to fail fast and explicitly # if passed arguments don't match the function signature return _ContextManagerHelper(_connect_read_pipe(pipe, limit)) From d5c207b68295fa771535db94e1849f7a1a989c0f Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 27 May 2019 15:38:58 +0300 Subject: [PATCH 76/82] Fix notes --- Lib/asyncio/streams.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 01a5abe888c67d..26903cc9e9910f 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1246,15 +1246,13 @@ class _OptionalAwait: # if not awaited # It prevents "coroutine is never awaited" message - __slot___ = ('_method', '_args', '_kwargs') + __slots___ = ('_method',) - def __init__(self, method, *args, **kwargs): + def __init__(self, method): self._method = method - self._args = args - self._kwargs = kwargs def __await__(self): - return self._method(*self._args, *self._kwargs).__await__() + return self._method().__await__() class Stream: From 0521b1fe16acb24945f89b8c40fa6637e81b978e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 27 May 2019 15:54:50 +0300 Subject: [PATCH 77/82] Return _OptionalAwait from stream.close() --- Lib/asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 10e5a38f2ed3f7..cd1c7ccf7a86fa 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1378,7 +1378,7 @@ def can_write_eof(self): def close(self): self._transport.close() - return self._protocol._get_close_waiter(self) + return _OptionalAwait(self.wait_closed) def is_closing(self): return self._transport.is_closing() From 5b4b7fd64a0c55bfe68ed7bc3ced5f0adec0ff8c Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 27 May 2019 16:13:03 +0300 Subject: [PATCH 78/82] Make _check_read() and _check_write() top-level functions --- Lib/asyncio/streams.py | 34 ++++++++++++++------------- Lib/test/test_asyncio/test_streams.py | 14 +++++------ 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index cd1c7ccf7a86fa..f05130d37fd9b0 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -32,13 +32,15 @@ class StreamMode(enum.Flag): WRITE = enum.auto() READWRITE = READ | WRITE - def _check_read(self): - if not self & self.READ: - raise RuntimeError("The stream is write-only") - def _check_write(self): - if not self & self.WRITE: - raise RuntimeError("The stream is read-only") +def _check_read(mode): + if not mode & StreamMode.READ: + raise RuntimeError("The stream is write-only") + + +def _check_write(mode): + if not mode & StreamMode.WRITE: + raise RuntimeError("The stream is read-only") class _ContextManagerHelper: @@ -1338,12 +1340,12 @@ def transport(self): return self._transport def write(self, data): - self._mode._check_write() + _check_write(self._mode) self._transport.write(data) return self._fast_drain() def writelines(self, data): - self._mode._check_write() + _check_write(self._mode) self._transport.writelines(data) return self._fast_drain() @@ -1368,7 +1370,7 @@ def _fast_drain(self): return _OptionalAwait(self.drain) def write_eof(self): - self._mode._check_write() + _check_write(self._mode) return self._transport.write_eof() def can_write_eof(self): @@ -1401,7 +1403,7 @@ async def drain(self): w.write(data) await w.drain() """ - self._mode._check_write() + _check_write(self._mode) exc = self.exception() if exc is not None: raise exc @@ -1470,7 +1472,7 @@ def at_eof(self): return self._eof and not self._buffer def feed_data(self, data): - self._mode._check_read() + _check_read(self._mode) assert not self._eof, 'feed_data after feed_eof' if not data: @@ -1536,7 +1538,7 @@ async def readline(self): If stream was paused, this function will automatically resume it if needed. """ - self._mode._check_read() + _check_read(self._mode) sep = b'\n' seplen = len(sep) try: @@ -1572,7 +1574,7 @@ async def readuntil(self, separator=b'\n'): LimitOverrunError exception will be raised, and the data will be left in the internal buffer, so it can be read again. """ - self._mode._check_read() + _check_read(self._mode) seplen = len(separator) if seplen == 0: raise ValueError('Separator should be at least one-byte string') @@ -1664,7 +1666,7 @@ async def read(self, n=-1): If stream was paused, this function will automatically resume it if needed. """ - self._mode._check_read() + _check_read(self._mode) if self._exception is not None: raise self._exception @@ -1710,7 +1712,7 @@ async def readexactly(self, n): If stream was paused, this function will automatically resume it if needed. """ - self._mode._check_read() + _check_read(self._mode) if n < 0: raise ValueError('readexactly size can not be less than zero') @@ -1738,7 +1740,7 @@ async def readexactly(self, n): return data def __aiter__(self): - self._mode._check_read() + _check_read(self._mode) return self async def __anext__(self): diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index d137cdbb27699d..78e16d5d48a1da 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -18,7 +18,7 @@ ssl = None import asyncio -from asyncio.streams import _StreamProtocol +from asyncio.streams import _StreamProtocol, _check_read, _check_write from test.test_asyncio import utils as test_utils @@ -28,20 +28,20 @@ def tearDownModule(): class StreamModeTests(unittest.TestCase): def test__check_read_ok(self): - self.assertIsNone(asyncio.StreamMode.READ._check_read()) - self.assertIsNone(asyncio.StreamMode.READWRITE._check_read()) + self.assertIsNone(_check_read(asyncio.StreamMode.READ)) + self.assertIsNone(_check_read(asyncio.StreamMode.READWRITE)) def test__check_read_fail(self): with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): - asyncio.StreamMode.WRITE._check_read() + _check_read(asyncio.StreamMode.WRITE) def test__check_write_ok(self): - self.assertIsNone(asyncio.StreamMode.WRITE._check_write()) - self.assertIsNone(asyncio.StreamMode.READWRITE._check_write()) + self.assertIsNone(_check_write(asyncio.StreamMode.WRITE)) + self.assertIsNone(_check_write(asyncio.StreamMode.READWRITE)) def test__check_write_fail(self): with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): - asyncio.StreamMode.READ._check_write() + _check_write(asyncio.StreamMode.READ) class StreamTests(test_utils.TestCase): From 567f1548eaea1d81fa412ecd2e18ce8581aac725 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 27 May 2019 16:41:56 +0300 Subject: [PATCH 79/82] Return empty tuple instead of list if StreamServer.sockets is empty --- Lib/asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index f05130d37fd9b0..68e6e7b52a3dda 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -266,7 +266,7 @@ def is_bound(self): def sockets(self): # multiple value for socket bound to both IPv4 and IPv6 families if self._server_impl is None: - return [] + return () return self._server_impl.sockets def is_serving(self): From 8cf6bf18dd79e5171fb781ceb79220aad7de6259 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 27 May 2019 16:44:09 +0300 Subject: [PATCH 80/82] Rename private helpers --- Lib/asyncio/streams.py | 24 ++++++++++++------------ Lib/test/test_asyncio/test_streams.py | 22 +++++++++++----------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 68e6e7b52a3dda..480f1a3fdd74ed 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -33,12 +33,12 @@ class StreamMode(enum.Flag): READWRITE = READ | WRITE -def _check_read(mode): +def _ensure_can_read(mode): if not mode & StreamMode.READ: raise RuntimeError("The stream is write-only") -def _check_write(mode): +def _ensure_can_write(mode): if not mode & StreamMode.WRITE: raise RuntimeError("The stream is read-only") @@ -1340,12 +1340,12 @@ def transport(self): return self._transport def write(self, data): - _check_write(self._mode) + _ensure_can_write(self._mode) self._transport.write(data) return self._fast_drain() def writelines(self, data): - _check_write(self._mode) + _ensure_can_write(self._mode) self._transport.writelines(data) return self._fast_drain() @@ -1370,7 +1370,7 @@ def _fast_drain(self): return _OptionalAwait(self.drain) def write_eof(self): - _check_write(self._mode) + _ensure_can_write(self._mode) return self._transport.write_eof() def can_write_eof(self): @@ -1403,7 +1403,7 @@ async def drain(self): w.write(data) await w.drain() """ - _check_write(self._mode) + _ensure_can_write(self._mode) exc = self.exception() if exc is not None: raise exc @@ -1472,7 +1472,7 @@ def at_eof(self): return self._eof and not self._buffer def feed_data(self, data): - _check_read(self._mode) + _ensure_can_read(self._mode) assert not self._eof, 'feed_data after feed_eof' if not data: @@ -1538,7 +1538,7 @@ async def readline(self): If stream was paused, this function will automatically resume it if needed. """ - _check_read(self._mode) + _ensure_can_read(self._mode) sep = b'\n' seplen = len(sep) try: @@ -1574,7 +1574,7 @@ async def readuntil(self, separator=b'\n'): LimitOverrunError exception will be raised, and the data will be left in the internal buffer, so it can be read again. """ - _check_read(self._mode) + _ensure_can_read(self._mode) seplen = len(separator) if seplen == 0: raise ValueError('Separator should be at least one-byte string') @@ -1666,7 +1666,7 @@ async def read(self, n=-1): If stream was paused, this function will automatically resume it if needed. """ - _check_read(self._mode) + _ensure_can_read(self._mode) if self._exception is not None: raise self._exception @@ -1712,7 +1712,7 @@ async def readexactly(self, n): If stream was paused, this function will automatically resume it if needed. """ - _check_read(self._mode) + _ensure_can_read(self._mode) if n < 0: raise ValueError('readexactly size can not be less than zero') @@ -1740,7 +1740,7 @@ async def readexactly(self, n): return data def __aiter__(self): - _check_read(self._mode) + _ensure_can_read(self._mode) return self async def __anext__(self): diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 78e16d5d48a1da..623dbbf63f81b3 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -18,7 +18,7 @@ ssl = None import asyncio -from asyncio.streams import _StreamProtocol, _check_read, _check_write +from asyncio.streams import _StreamProtocol, _ensure_can_read, _ensure_can_write from test.test_asyncio import utils as test_utils @@ -27,21 +27,21 @@ def tearDownModule(): class StreamModeTests(unittest.TestCase): - def test__check_read_ok(self): - self.assertIsNone(_check_read(asyncio.StreamMode.READ)) - self.assertIsNone(_check_read(asyncio.StreamMode.READWRITE)) + def test__ensure_can_read_ok(self): + self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READ)) + self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READWRITE)) - def test__check_read_fail(self): + def test__ensure_can_read_fail(self): with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): - _check_read(asyncio.StreamMode.WRITE) + _ensure_can_read(asyncio.StreamMode.WRITE) - def test__check_write_ok(self): - self.assertIsNone(_check_write(asyncio.StreamMode.WRITE)) - self.assertIsNone(_check_write(asyncio.StreamMode.READWRITE)) + def test__ensure_can_write_ok(self): + self.assertIsNone(_ensure_can_write(asyncio.StreamMode.WRITE)) + self.assertIsNone(_ensure_can_write(asyncio.StreamMode.READWRITE)) - def test__check_write_fail(self): + def test__ensure_can_write_fail(self): with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): - _check_write(asyncio.StreamMode.READ) + _ensure_can_write(asyncio.StreamMode.READ) class StreamTests(test_utils.TestCase): From 99f483961e88f44e40d2890fb065cffbb05b083d Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 27 May 2019 19:21:49 +0300 Subject: [PATCH 81/82] Make more detailed NEWS --- .../next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst b/Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst index 1009eb6e841669..d08c0e287edfb4 100644 --- a/Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst +++ b/Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst @@ -1,2 +1,6 @@ -Merge asyncio.StreamReader and asyncio.StreamWriter into asyncio.Stream -class with readonly, writeonly and readwrite modes. +Introduce :class:`asyncio.Stream` class that merges :class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter` functionality. +:class:`asyncio.Stream` can work in readonly, writeonly and readwrite modes. +Provide :func:`asyncio.connect`, :func:`asyncio.connect_unix`, :func:`asyncio.connect_read_pipe` and :func:`asyncio.connect_write_pipe` factories to open :class:`asyncio.Stream` connections. Provide :class:`asyncio.StreamServer` and :class:`UnixStreamServer` to serve servers with asyncio.Stream API. +Modify :func:`asyncio.create_subprocess_shell` and :func:`asyncio.create_subprocess_exec` to use :class:`asyncio.Stream` instead of deprecated :class:`StreamReader` and :class:`StreamWriter`. +Deprecate :class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter`. +Deprecate usage of private classes, e.g. :class:`asyncio.FlowControlMixing` and :class:`asyncio.StreamReaderProtocol` outside of asyncio package. From 4ac5ce1153a76c328379bae22e793150dfb39a8f Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 27 May 2019 20:13:40 +0300 Subject: [PATCH 82/82] Fix tests --- Lib/test/test_asyncio/test_streams.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 623dbbf63f81b3..df3d7e7dfa455c 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1678,7 +1678,7 @@ async def serve(stream): async def test(): srv = asyncio.StreamServer(serve, '127.0.0.1', 0) await srv.bind() - self.assertRegex(repr(srv), r'') + self.assertRegex(repr(srv), r'') await srv.close() self.loop.run_until_complete(test()) @@ -1690,7 +1690,7 @@ async def serve(stream): async def test(): srv = asyncio.StreamServer(serve, '127.0.0.1', 0) await srv.start_serving() - self.assertRegex(repr(srv), r'') + self.assertRegex(repr(srv), r'') await srv.close() self.loop.run_until_complete(test())