diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index c8a2f9f25634ef..13d4df1110c0e8 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -447,6 +447,38 @@ def create_task(self, coro, *, name=None, context=None): return task + def eager_task_factory(self, coro, *, name=None, context=None): + """Start a coroutine. + + This runs the coroutine until it first suspends itself. + + If it runs till completion or fails without suspending, + return a future with the result or exception. + + Otherwise schedule the resumption and return a task. + """ + self._check_closed() + # Do not go through the task factory. + # This _is_ the task factory. + if tasks.Task is not tasks._PyTask: + task = tasks.Task(coro, loop=self, name=name, context=context) + else: + try: + yield_result = coro.send(None) + except BaseException as exc: + fut = self.create_future() + # XXX What about AsyncStopIteration? + if isinstance(exc, StopIteration): + fut.set_result(exc.value) + else: + fut.set_exception(exc) + return fut + task = tasks.Task(coro, loop=self, name=name, context=context, + yield_result=yield_result) + if task._source_traceback: + del task._source_traceback[-1] + return task + def set_task_factory(self, factory): """Set a task factory that will be used by loop.create_task(). diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 5d5e2a8a85dd48..e7a5df3b7d0827 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -141,11 +141,14 @@ def create_task(self, coro, *, name=None, context=None): raise RuntimeError(f"TaskGroup {self!r} is finished") if self._aborting: raise RuntimeError(f"TaskGroup {self!r} is shutting down") - if context is None: + if hasattr(self._loop, "eager_task_factory"): + task = self._loop.eager_task_factory(coro, name=name, context=context) + elif context is None: task = self._loop.create_task(coro) else: task = self._loop.create_task(coro, context=context) - tasks._set_task_name(task, name) + if not task.done(): # If it's done already, it's a future + tasks._set_task_name(task, name) task.add_done_callback(self._on_task_done) self._tasks.add(task) return task diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 571013745aa03a..678f84cf04243b 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -75,6 +75,8 @@ def _set_task_name(task, name): set_name(name) +_NOT_SET = object() + class Task(futures._PyFuture): # Inherit Python Task implementation # from a Python Future implementation. @@ -93,7 +95,8 @@ class Task(futures._PyFuture): # Inherit Python Task implementation # status is still pending _log_destroy_pending = True - def __init__(self, coro, *, loop=None, name=None, context=None): + def __init__(self, coro, *, loop=None, name=None, context=None, + yield_result=_NOT_SET): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] @@ -117,7 +120,10 @@ def __init__(self, coro, *, loop=None, name=None, context=None): else: self._context = context - self._loop.call_soon(self.__step, context=self._context) + if yield_result is _NOT_SET: + self._loop.call_soon(self.__step, context=self._context) + else: + self.__step2(yield_result) _register_task(self) def __del__(self): @@ -287,6 +293,12 @@ def __step(self, exc=None): except BaseException as exc: super().set_exception(exc) else: + self.__step2(result) + finally: + _leave_task(self._loop, self) + self = None # Needed to break cycles when an exception occurs. + + def __step2(self, result): blocking = getattr(result, '_asyncio_future_blocking', None) if blocking is not None: # Yielded Future must come from Future.__iter__(). @@ -333,9 +345,6 @@ def __step(self, exc=None): new_exc = RuntimeError(f'Task got bad yield: {result!r}') self._loop.call_soon( self.__step, new_exc, context=self._context) - finally: - _leave_task(self._loop, self) - self = None # Needed to break cycles when an exception occurs. def __wakeup(self, future): try: