Skip to content

Commit adb2a01

Browse files
committed
Added optional flag skip_epoch_completed to Engine.terminate_epoch()
1 parent 6f8ad2a commit adb2a01

File tree

3 files changed

+59
-25
lines changed

3 files changed

+59
-25
lines changed

ignite/engine/engine.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
142142
self.should_terminate = False
143143
self.skip_completed_after_termination = False
144144
self.should_terminate_single_epoch = False
145+
self._skip_epoch_completed_after_termination = False
145146
self.should_interrupt = False
146147
self.state = State()
147148
self._state_dict_user_keys: List[str] = []
@@ -628,7 +629,7 @@ def terminate():
628629
self.should_terminate = True
629630
self.skip_completed_after_termination = skip_completed
630631

631-
def terminate_epoch(self) -> None:
632+
def terminate_epoch(self, skip_epoch_completed: bool = False) -> None:
632633
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
633634
continues from the next epoch. The following events are triggered:
634635
@@ -638,12 +639,17 @@ def terminate_epoch(self) -> None:
638639
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
639640
- :attr:`~ignite.engine.events.Events.EPOCH_STARTED`
640641
- ...
642+
643+
Args:
644+
skip_epoch_completed: if True, the event :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
645+
is not fired after :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`. Default is False.
641646
"""
642647
self.logger.info(
643648
"Terminate current epoch is signaled. "
644649
"Current epoch iteration will stop after current iteration is finished."
645650
)
646651
self.should_terminate_single_epoch = True
652+
self._skip_epoch_completed_after_termination = skip_epoch_completed
647653

648654
def _handle_exception(self, e: BaseException) -> None:
649655
if Events.EXCEPTION_RAISED in self._event_handlers:
@@ -982,11 +988,15 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
982988
# time is available for handlers but must be updated after fire
983989
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
984990

985-
handlers_start_time = time.time()
986-
self._fire_event(Events.EPOCH_COMPLETED)
987-
epoch_time_taken += time.time() - handlers_start_time
988-
# update time wrt handlers
989-
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
991+
if not self._skip_epoch_completed_after_termination:
992+
handlers_start_time = time.time()
993+
self._fire_event(Events.EPOCH_COMPLETED)
994+
epoch_time_taken += time.time() - handlers_start_time
995+
# update time wrt handlers
996+
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
997+
else:
998+
self._skip_epoch_completed_after_termination = False
999+
9901000
yield from self._maybe_terminate_or_interrupt()
9911001

9921002
hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
@@ -1167,11 +1177,15 @@ def _internal_run_legacy(self) -> State:
11671177
# time is available for handlers but must be updated after fire
11681178
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
11691179

1170-
handlers_start_time = time.time()
1171-
self._fire_event(Events.EPOCH_COMPLETED)
1172-
epoch_time_taken += time.time() - handlers_start_time
1173-
# update time wrt handlers
1174-
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
1180+
if not self._skip_epoch_completed_after_termination:
1181+
handlers_start_time = time.time()
1182+
self._fire_event(Events.EPOCH_COMPLETED)
1183+
epoch_time_taken += time.time() - handlers_start_time
1184+
# update time wrt handlers
1185+
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
1186+
else:
1187+
self._skip_epoch_completed_after_termination = False
1188+
11751189
self._maybe_terminate_legacy()
11761190

11771191
hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)

ignite/engine/events.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,9 @@ class Events(EventEnum):
259259
- TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
260260
after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
261261
:meth:`~ignite.engine.engine.Engine.terminate()` call.
262-
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
263-
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
262+
- EPOCH_COMPLETED : triggered when the epoch is ended. This is triggered even
263+
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called,
264+
unless the flag `skip_epoch_completed` is set to True.
264265
265266
- TERMINATE : triggered when the run is about to end completely,
266267
after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call.
@@ -272,7 +273,7 @@ class Events(EventEnum):
272273
The table below illustrates which events are triggered when various termination methods are called.
273274
274275
.. list-table::
275-
:widths: 35 38 28 20 20
276+
:widths: 38 38 28 20 20
276277
:header-rows: 1
277278
278279
* - Method
@@ -290,6 +291,11 @@ class Events(EventEnum):
290291
- ✔
291292
- ✗
292293
- ✔
294+
* - :meth:`~ignite.engine.engine.Engine.terminate_epoch()` with `skip_epoch_completed=True`
295+
- ✔
296+
- ✗
297+
- ✗
298+
- ✔
293299
* - :meth:`~ignite.engine.engine.Engine.terminate()`
294300
- ✗
295301
- ✔

tests/ignite/engine/test_engine.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,19 @@ def assert_no_exceptions(ee):
292292
assert engine.called_events[0] == (0, 0, Events.STARTED)
293293
assert engine._dataloader_iter is None
294294

295-
@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
296-
def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length):
295+
@pytest.mark.parametrize(
296+
"data, epoch_length, skip_epoch_completed",
297+
[(None, 10, False), (range(10), None, False), (None, 10, True), (range(10), None, True)],
298+
)
299+
def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length, skip_epoch_completed):
297300
real_epoch_length = epoch_length if data is None else len(data)
298301
iteration_to_stop = real_epoch_length + 4
299302

300303
engine = Engine(MagicMock(return_value=1))
301304

302305
def start_of_iteration_handler(engine):
303306
if engine.state.iteration == iteration_to_stop:
304-
engine.terminate_epoch()
307+
engine.terminate_epoch(skip_epoch_completed)
305308

306309
max_epochs = 3
307310
engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
@@ -312,15 +315,19 @@ def start_of_iteration_handler(engine):
312315
assert state.epoch == max_epochs
313316

314317
@pytest.mark.parametrize(
315-
"terminate_epoch_event, i",
318+
"terminate_epoch_event, i, skip_epoch_completed",
316319
[
317-
(Events.GET_BATCH_STARTED(once=12), 12),
318-
(Events.GET_BATCH_COMPLETED(once=12), 12),
319-
(Events.ITERATION_STARTED(once=14), 14),
320-
(Events.ITERATION_COMPLETED(once=14), 14),
320+
(Events.GET_BATCH_STARTED(once=12), 12, False),
321+
(Events.GET_BATCH_COMPLETED(once=12), 12, False),
322+
(Events.ITERATION_STARTED(once=14), 14, False),
323+
(Events.ITERATION_COMPLETED(once=14), 14, False),
324+
(Events.GET_BATCH_STARTED(once=12), 12, True),
325+
(Events.GET_BATCH_COMPLETED(once=12), 12, True),
326+
(Events.ITERATION_STARTED(once=14), 14, True),
327+
(Events.ITERATION_COMPLETED(once=14), 14, True),
321328
],
322329
)
323-
def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i):
330+
def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i, skip_epoch_completed):
324331
engine = RecordedEngine(MagicMock(return_value=1))
325332
data = range(10)
326333
max_epochs = 3
@@ -331,23 +338,27 @@ def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i):
331338

332339
@engine.on(terminate_epoch_event)
333340
def call_terminate_epoch():
341+
assert not engine._skip_epoch_completed_after_termination
334342
nonlocal call_count
335343
if call_count < 1:
336-
engine.terminate_epoch()
344+
engine.terminate_epoch(skip_epoch_completed)
345+
assert engine._skip_epoch_completed_after_termination == skip_epoch_completed
346+
337347
call_count += 1
338348

339349
@engine.on(Events.TERMINATE_SINGLE_EPOCH)
340350
def check_previous_events(iter_counter):
341351
e = i // len(data) + 1
342-
343352
assert engine.called_events[0] == (0, 0, Events.STARTED)
344353
assert engine.called_events[-2] == (e, i, terminate_epoch_event)
345354
assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
355+
assert engine._skip_epoch_completed_after_termination == skip_epoch_completed
346356

347357
@engine.on(Events.EPOCH_COMPLETED)
348358
def check_previous_events2():
349359
e = i // len(data) + 1
350360
if e == engine.state.epoch and i == engine.state.iteration:
361+
assert not skip_epoch_completed
351362
assert engine.called_events[-3] == (e, i, terminate_epoch_event)
352363
assert engine.called_events[-2] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
353364
assert engine.called_events[-1] == (e, i, Events.EPOCH_COMPLETED)
@@ -357,6 +368,9 @@ def check_previous_events2():
357368
assert engine.state.epoch == max_epochs
358369
assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data)
359370

371+
epoch_completed_events = [e for e in engine.called_events if e[2] == Events.EPOCH_COMPLETED.name]
372+
assert len(epoch_completed_events) == max_epochs - skip_epoch_completed
373+
360374
@pytest.mark.parametrize("data", [None, "mock_data_loader"])
361375
def test_iteration_events_are_fired(self, data):
362376
max_epochs = 5

0 commit comments

Comments
 (0)