@@ -292,16 +292,19 @@ def assert_no_exceptions(ee):
292
292
assert engine .called_events [0 ] == (0 , 0 , Events .STARTED )
293
293
assert engine ._dataloader_iter is None
294
294
295
- @pytest .mark .parametrize ("data, epoch_length" , [(None , 10 ), (range (10 ), None )])
296
- def test_terminate_epoch_stops_mid_epoch (self , data , epoch_length ):
295
+ @pytest .mark .parametrize (
296
+ "data, epoch_length, skip_epoch_completed" ,
297
+ [(None , 10 , False ), (range (10 ), None , False ), (None , 10 , True ), (range (10 ), None , True )],
298
+ )
299
+ def test_terminate_epoch_stops_mid_epoch (self , data , epoch_length , skip_epoch_completed ):
297
300
real_epoch_length = epoch_length if data is None else len (data )
298
301
iteration_to_stop = real_epoch_length + 4
299
302
300
303
engine = Engine (MagicMock (return_value = 1 ))
301
304
302
305
def start_of_iteration_handler (engine ):
303
306
if engine .state .iteration == iteration_to_stop :
304
- engine .terminate_epoch ()
307
+ engine .terminate_epoch (skip_epoch_completed )
305
308
306
309
max_epochs = 3
307
310
engine .add_event_handler (Events .ITERATION_STARTED , start_of_iteration_handler )
@@ -312,15 +315,19 @@ def start_of_iteration_handler(engine):
312
315
assert state .epoch == max_epochs
313
316
314
317
@pytest .mark .parametrize (
315
- "terminate_epoch_event, i" ,
318
+ "terminate_epoch_event, i, skip_epoch_completed " ,
316
319
[
317
- (Events .GET_BATCH_STARTED (once = 12 ), 12 ),
318
- (Events .GET_BATCH_COMPLETED (once = 12 ), 12 ),
319
- (Events .ITERATION_STARTED (once = 14 ), 14 ),
320
- (Events .ITERATION_COMPLETED (once = 14 ), 14 ),
320
+ (Events .GET_BATCH_STARTED (once = 12 ), 12 , False ),
321
+ (Events .GET_BATCH_COMPLETED (once = 12 ), 12 , False ),
322
+ (Events .ITERATION_STARTED (once = 14 ), 14 , False ),
323
+ (Events .ITERATION_COMPLETED (once = 14 ), 14 , False ),
324
+ (Events .GET_BATCH_STARTED (once = 12 ), 12 , True ),
325
+ (Events .GET_BATCH_COMPLETED (once = 12 ), 12 , True ),
326
+ (Events .ITERATION_STARTED (once = 14 ), 14 , True ),
327
+ (Events .ITERATION_COMPLETED (once = 14 ), 14 , True ),
321
328
],
322
329
)
323
- def test_terminate_epoch_events_sequence (self , terminate_epoch_event , i ):
330
+ def test_terminate_epoch_events_sequence (self , terminate_epoch_event , i , skip_epoch_completed ):
324
331
engine = RecordedEngine (MagicMock (return_value = 1 ))
325
332
data = range (10 )
326
333
max_epochs = 3
@@ -331,23 +338,27 @@ def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i):
331
338
332
339
@engine .on (terminate_epoch_event )
333
340
def call_terminate_epoch ():
341
+ assert not engine ._skip_epoch_completed_after_termination
334
342
nonlocal call_count
335
343
if call_count < 1 :
336
- engine .terminate_epoch ()
344
+ engine .terminate_epoch (skip_epoch_completed )
345
+ assert engine ._skip_epoch_completed_after_termination == skip_epoch_completed
346
+
337
347
call_count += 1
338
348
339
349
@engine .on (Events .TERMINATE_SINGLE_EPOCH )
340
350
def check_previous_events (iter_counter ):
341
351
e = i // len (data ) + 1
342
-
343
352
assert engine .called_events [0 ] == (0 , 0 , Events .STARTED )
344
353
assert engine .called_events [- 2 ] == (e , i , terminate_epoch_event )
345
354
assert engine .called_events [- 1 ] == (e , i , Events .TERMINATE_SINGLE_EPOCH )
355
+ assert engine ._skip_epoch_completed_after_termination == skip_epoch_completed
346
356
347
357
@engine .on (Events .EPOCH_COMPLETED )
348
358
def check_previous_events2 ():
349
359
e = i // len (data ) + 1
350
360
if e == engine .state .epoch and i == engine .state .iteration :
361
+ assert not skip_epoch_completed
351
362
assert engine .called_events [- 3 ] == (e , i , terminate_epoch_event )
352
363
assert engine .called_events [- 2 ] == (e , i , Events .TERMINATE_SINGLE_EPOCH )
353
364
assert engine .called_events [- 1 ] == (e , i , Events .EPOCH_COMPLETED )
@@ -357,6 +368,9 @@ def check_previous_events2():
357
368
assert engine .state .epoch == max_epochs
358
369
assert (max_epochs - 1 ) * len (data ) < engine .state .iteration < max_epochs * len (data )
359
370
371
+ epoch_completed_events = [e for e in engine .called_events if e [2 ] == Events .EPOCH_COMPLETED .name ]
372
+ assert len (epoch_completed_events ) == max_epochs - skip_epoch_completed
373
+
360
374
@pytest .mark .parametrize ("data" , [None , "mock_data_loader" ])
361
375
def test_iteration_events_are_fired (self , data ):
362
376
max_epochs = 5
0 commit comments