1
1
import functools
2
2
import logging
3
- import math
4
3
import time
5
4
import warnings
6
5
import weakref
@@ -747,14 +746,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
747
746
748
747
@staticmethod
749
748
def _is_done (state : State ) -> bool :
750
- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
751
749
is_done_count = (
752
750
state .epoch_length is not None
753
751
and state .max_epochs is not None
754
752
and state .iteration >= state .epoch_length * state .max_epochs
755
753
)
756
754
is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
757
- return is_done_iters or is_done_count or is_done_epochs
755
+ return is_done_count or is_done_epochs
758
756
759
757
def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
760
758
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -796,14 +794,13 @@ def run(
796
794
self ,
797
795
data : Optional [Iterable ] = None ,
798
796
max_epochs : Optional [int ] = None ,
799
- max_iters : Optional [int ] = None ,
800
797
epoch_length : Optional [int ] = None ,
801
798
) -> State :
802
799
"""Runs the ``process_function`` over the passed data.
803
800
804
801
Engine has a state and the following logic is applied in this function:
805
802
806
- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, if provided.
803
+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, if provided.
807
804
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
808
805
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
809
806
provided, state is kept and used in the function.
@@ -821,9 +818,6 @@ def run(
821
818
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
822
819
determined as the iteration on which data iterator raises `StopIteration`.
823
820
This argument should not change if run is resuming from a state.
824
- max_iters: Number of iterations to run for.
825
- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
826
-
827
821
Returns:
828
822
State: output state.
829
823
@@ -874,6 +868,8 @@ def switch_batch(engine):
874
868
875
869
if self .state .max_epochs is None or (self ._is_done (self .state ) and self ._internal_run_generator is None ):
876
870
# Create new state
871
+ if max_epochs is None :
872
+ max_epochs = 1
877
873
if epoch_length is None :
878
874
if data is None :
879
875
raise ValueError ("epoch_length should be provided if data is None" )
@@ -882,22 +878,9 @@ def switch_batch(engine):
882
878
if epoch_length is not None and epoch_length < 1 :
883
879
raise ValueError ("Input data has zero size. Please provide non-empty data" )
884
880
885
- if max_iters is None :
886
- if max_epochs is None :
887
- max_epochs = 1
888
- else :
889
- if max_epochs is not None :
890
- raise ValueError (
891
- "Arguments max_iters and max_epochs are mutually exclusive."
892
- "Please provide only max_epochs or max_iters."
893
- )
894
- if epoch_length is not None :
895
- max_epochs = math .ceil (max_iters / epoch_length )
896
-
897
881
self .state .iteration = 0
898
882
self .state .epoch = 0
899
883
self .state .max_epochs = max_epochs
900
- self .state .max_iters = max_iters
901
884
self .state .epoch_length = epoch_length
902
885
# Reset generator if previously used
903
886
self ._internal_run_generator = None
@@ -1095,18 +1078,12 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
1095
1078
if self .state .epoch_length is None :
1096
1079
# Define epoch length and stop the epoch
1097
1080
self .state .epoch_length = iter_counter
1098
- if self .state .max_iters is not None :
1099
- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
1100
1081
break
1101
1082
1102
1083
# Should exit while loop if we can not iterate
1103
1084
if should_exit :
1104
- if not self ._is_done (self .state ):
1105
- total_iters = (
1106
- self .state .epoch_length * self .state .max_epochs
1107
- if self .state .max_epochs is not None
1108
- else self .state .max_iters
1109
- )
1085
+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1086
+ total_iters = self .state .epoch_length * self .state .max_epochs
1110
1087
1111
1088
warnings .warn (
1112
1089
"Data iterator can not provide data anymore but required total number of "
@@ -1137,10 +1114,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
1137
1114
if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
1138
1115
break
1139
1116
1140
- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1141
- self .should_terminate = True
1142
- raise _EngineTerminateException ()
1143
-
1144
1117
except _EngineTerminateSingleEpochException :
1145
1118
self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
1146
1119
self ._setup_dataloader_iter ()
@@ -1278,18 +1251,12 @@ def _run_once_on_dataset_legacy(self) -> float:
1278
1251
if self .state .epoch_length is None :
1279
1252
# Define epoch length and stop the epoch
1280
1253
self .state .epoch_length = iter_counter
1281
- if self .state .max_iters is not None :
1282
- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
1283
1254
break
1284
1255
1285
1256
# Should exit while loop if we can not iterate
1286
1257
if should_exit :
1287
- if not self ._is_done (self .state ):
1288
- total_iters = (
1289
- self .state .epoch_length * self .state .max_epochs
1290
- if self .state .max_epochs is not None
1291
- else self .state .max_iters
1292
- )
1258
+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1259
+ total_iters = self .state .epoch_length * self .state .max_epochs
1293
1260
1294
1261
warnings .warn (
1295
1262
"Data iterator can not provide data anymore but required total number of "
@@ -1320,10 +1287,6 @@ def _run_once_on_dataset_legacy(self) -> float:
1320
1287
if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
1321
1288
break
1322
1289
1323
- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1324
- self .should_terminate = True
1325
- raise _EngineTerminateException ()
1326
-
1327
1290
except _EngineTerminateSingleEpochException :
1328
1291
self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
1329
1292
self ._setup_dataloader_iter ()
0 commit comments