Skip to content

Allow to terminate an epoch without firing Events.EPOCH_COMPLETED #3313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 9, 2024
Merged
79 changes: 56 additions & 23 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,12 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self._process_function = process_function
self.last_event_name: Optional[Events] = None
self.should_terminate = False
self.skip_completed_after_termination = False
self.should_terminate_single_epoch = False
# should_terminate flag: False - don't terminate, True - terminate,
# "skip_completed" - terminate and skip the event "COMPLETED"
self.should_terminate: Union[bool, str] = False
# should_terminate_single_epoch flag: False - don't terminate, True - terminate,
# "skip_epoch_completed" - terminate and skip the event "EPOCH_COMPLETED"
self.should_terminate_single_epoch: Union[bool, str] = False
self.should_interrupt = False
self.state = State()
self._state_dict_user_keys: List[str] = []
Expand Down Expand Up @@ -546,7 +549,7 @@ def terminate(self, skip_completed: bool = False) -> None:
- ...
- Terminating event
- :attr:`~ignite.engine.events.Events.TERMINATE`
- :attr:`~ignite.engine.events.Events.COMPLETED`
- :attr:`~ignite.engine.events.Events.COMPLETED` (unless `skip_completed=True`)

Args:
skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after
Expand Down Expand Up @@ -625,25 +628,31 @@ def terminate():
Added `skip_completed` flag
"""
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.should_terminate = True
self.skip_completed_after_termination = skip_completed
self.should_terminate = "skip_completed" if skip_completed else True

def terminate_epoch(self) -> None:
def terminate_epoch(self, skip_epoch_completed: bool = False) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
continues from the next epoch. The following events are triggered:

- ...
- Event on which ``terminate_epoch`` method is called
- :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED` (unless `skip_epoch_completed=True`)
- :attr:`~ignite.engine.events.Events.EPOCH_STARTED`
- ...

Args:
skip_epoch_completed: if True, the event :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
is not fired after :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`. Default is False.

.. versionchanged:: 0.5.2
Added `skip_epoch_completed` flag
"""
self.logger.info(
"Terminate current epoch is signaled. "
"Current epoch iteration will stop after current iteration is finished."
)
self.should_terminate_single_epoch = True
self.should_terminate_single_epoch = "skip_epoch_completed" if skip_epoch_completed else True

def _handle_exception(self, e: BaseException) -> None:
if Events.EXCEPTION_RAISED in self._event_handlers:
Expand Down Expand Up @@ -982,11 +991,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

if self.should_terminate_single_epoch:
# We skip raising _EngineTerminateSingleEpochException on Events.EPOCH_COMPLETED
# as the epoch is already completed and there is nothing left to terminate
self.should_terminate_single_epoch = False
yield from self._maybe_terminate_or_interrupt()

hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
Expand All @@ -997,12 +1012,19 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
except _EngineTerminateException:
self._fire_event(Events.TERMINATE)

except _EngineTerminateSingleEpochException:
raise RuntimeError(
"The method terminate_epoch() should not be called on Event.STARTED or Event.EPOCH_STARTED. "
"If this is a desired behaviour, please open a feature request on "
"https://github.com/pytorch/ignite/issues/new/choose"
)

time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if not (self.should_terminate and self.skip_completed_after_termination):
if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
Expand Down Expand Up @@ -1121,7 +1143,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:

except _EngineTerminateSingleEpochException:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self._setup_dataloader_iter()

except _EngineTerminateException as e:
Expand Down Expand Up @@ -1167,11 +1188,17 @@ def _internal_run_legacy(self) -> State:
# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

if self.should_terminate_single_epoch:
# We skip raising _EngineTerminateSingleEpochException on Events.EPOCH_COMPLETED
# as the epoch is already completed and there is nothing left to terminate
self.should_terminate_single_epoch = False
self._maybe_terminate_legacy()

hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
Expand All @@ -1182,12 +1209,19 @@ def _internal_run_legacy(self) -> State:
except _EngineTerminateException:
self._fire_event(Events.TERMINATE)

except _EngineTerminateSingleEpochException:
raise RuntimeError(
"The method terminate_epoch() should not be called on Event.STARTED or Event.EPOCH_STARTED. "
"If this is a desired behaviour, please open a feature request on "
"https://github.com/pytorch/ignite/issues/new/choose"
)

time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if not (self.should_terminate and self.skip_completed_after_termination):
if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
Expand Down Expand Up @@ -1292,7 +1326,6 @@ def _run_once_on_dataset_legacy(self) -> float:

except _EngineTerminateSingleEpochException:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self._setup_dataloader_iter()

except _EngineTerminateException as e:
Expand Down
12 changes: 9 additions & 3 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ class Events(EventEnum):
- TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
:meth:`~ignite.engine.engine.Engine.terminate()` call.
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- EPOCH_COMPLETED : triggered when the epoch is ended. This is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called,
unless the flag `skip_epoch_completed` is set to True.

- TERMINATE : triggered when the run is about to end completely,
after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call.
Expand All @@ -272,7 +273,7 @@ class Events(EventEnum):
The table below illustrates which events are triggered when various termination methods are called.

.. list-table::
:widths: 35 38 28 20 20
:widths: 38 38 28 20 20
:header-rows: 1

* - Method
Expand All @@ -290,6 +291,11 @@ class Events(EventEnum):
- ✔
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate_epoch()` with `skip_epoch_completed=True`
- ✔
- ✗
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate()`
- ✗
- ✔
Expand Down
71 changes: 54 additions & 17 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ def set_interrupt_resume_enabled(self, interrupt_resume_enabled):
def test_terminate(self, skip_completed):
engine = Engine(lambda e, b: 1)
assert not engine.should_terminate
assert not engine.skip_completed_after_termination

engine.terminate(skip_completed)
assert engine.should_terminate
assert engine.skip_completed_after_termination == skip_completed

if skip_completed:
assert engine.should_terminate == "skip_completed"
else:
assert engine.should_terminate == True # noqa: E712

def test_invalid_process_raises_with_invalid_signature(self):
with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"):
Expand Down Expand Up @@ -292,16 +295,19 @@ def assert_no_exceptions(ee):
assert engine.called_events[0] == (0, 0, Events.STARTED)
assert engine._dataloader_iter is None

@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length):
@pytest.mark.parametrize(
"data, epoch_length, skip_epoch_completed",
[(None, 10, False), (range(10), None, False), (None, 10, True), (range(10), None, True)],
)
def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length, skip_epoch_completed):
real_epoch_length = epoch_length if data is None else len(data)
iteration_to_stop = real_epoch_length + 4

engine = Engine(MagicMock(return_value=1))

def start_of_iteration_handler(engine):
if engine.state.iteration == iteration_to_stop:
engine.terminate_epoch()
engine.terminate_epoch(skip_epoch_completed)

max_epochs = 3
engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
Expand All @@ -312,15 +318,23 @@ def start_of_iteration_handler(engine):
assert state.epoch == max_epochs

@pytest.mark.parametrize(
"terminate_epoch_event, i",
"terminate_epoch_event, i, skip_epoch_completed",
[
(Events.GET_BATCH_STARTED(once=12), 12),
(Events.GET_BATCH_COMPLETED(once=12), 12),
(Events.ITERATION_STARTED(once=14), 14),
(Events.ITERATION_COMPLETED(once=14), 14),
(Events.GET_BATCH_STARTED(once=12), 12, False),
(Events.GET_BATCH_COMPLETED(once=12), 12, False),
(Events.ITERATION_STARTED(once=14), 14, False),
(Events.ITERATION_COMPLETED(once=14), 14, False),
(Events.GET_BATCH_STARTED(once=12), 12, True),
(Events.GET_BATCH_COMPLETED(once=12), 12, True),
(Events.ITERATION_STARTED(once=14), 14, True),
(Events.ITERATION_COMPLETED(once=14), 14, True),
(Events.STARTED, 30, False),
(Events.STARTED, 30, True),
(Events.EPOCH_STARTED(once=2), 10, False),
(Events.EPOCH_STARTED(once=2), 10, True),
],
)
def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i):
def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i, skip_epoch_completed):
engine = RecordedEngine(MagicMock(return_value=1))
data = range(10)
max_epochs = 3
Expand All @@ -331,31 +345,54 @@ def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i):

@engine.on(terminate_epoch_event)
def call_terminate_epoch():
assert not engine.should_terminate_single_epoch
nonlocal call_count
if call_count < 1:
engine.terminate_epoch()
engine.terminate_epoch(skip_epoch_completed)
if skip_epoch_completed:
assert engine.should_terminate_single_epoch == "skip_epoch_completed"
else:
assert engine.should_terminate_single_epoch == True # noqa: E712

call_count += 1

@engine.on(Events.EPOCH_STARTED)
def check_skip_reset():
if terminate_epoch_event != Events.EPOCH_STARTED:
assert engine.should_terminate_single_epoch == False # noqa: E712

@engine.on(Events.TERMINATE_SINGLE_EPOCH)
def check_previous_events(iter_counter):
e = i // len(data) + 1

assert engine.called_events[0] == (0, 0, Events.STARTED)
assert engine.called_events[-2] == (e, i, terminate_epoch_event)
assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
if skip_epoch_completed:
assert engine.should_terminate_single_epoch == "skip_epoch_completed"
else:
assert engine.should_terminate_single_epoch == True # noqa: E712

@engine.on(Events.EPOCH_COMPLETED)
def check_previous_events2():
e = i // len(data) + 1
if e == engine.state.epoch and i == engine.state.iteration:
assert not skip_epoch_completed
assert isinstance(engine.should_terminate_single_epoch, bool)
assert engine.called_events[-3] == (e, i, terminate_epoch_event)
assert engine.called_events[-2] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
assert engine.called_events[-1] == (e, i, Events.EPOCH_COMPLETED)

engine.run(data, max_epochs=max_epochs)
if terminate_epoch_event in [Events.STARTED, Events.EPOCH_STARTED]:
with pytest.raises(RuntimeError):
engine.run(data, max_epochs=max_epochs)
else:
engine.run(data, max_epochs=max_epochs)

assert engine.state.epoch == max_epochs
assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data)

assert engine.state.epoch == max_epochs
assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data)
epoch_completed_events = [e for e in engine.called_events if e[2] == Events.EPOCH_COMPLETED.name]
assert len(epoch_completed_events) == max_epochs - skip_epoch_completed

@pytest.mark.parametrize("data", [None, "mock_data_loader"])
def test_iteration_events_are_fired(self, data):
Expand Down
Loading