|
18 | 18 | amp = None
|
19 | 19 |
|
20 | 20 |
|
| 21 | +try: |
| 22 | + from torchvision.prototype import models as PM |
| 23 | +except ImportError: |
| 24 | + PM = None |
| 25 | + |
| 26 | + |
21 | 27 | def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
|
22 | 28 | model.train()
|
23 | 29 | metric_logger = utils.MetricLogger(delimiter=" ")
|
@@ -149,7 +155,12 @@ def main(args):
|
149 | 155 | print("Loading validation data")
|
150 | 156 | cache_path = _get_cache_path(valdir)
|
151 | 157 |
|
152 |
| - transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) |
| 158 | + if not args.weights: |
| 159 | + transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) |
| 160 | + else: |
| 161 | + fn = PM.video.__dict__[args.model] |
| 162 | + weights = PM._api.get_weight(fn, args.weights) |
| 163 | + transform_test = weights.transforms() |
153 | 164 |
|
154 | 165 | if args.cache_dataset and os.path.exists(cache_path):
|
155 | 166 | print(f"Loading dataset_test from {cache_path}")
|
@@ -200,7 +211,12 @@ def main(args):
|
200 | 211 | )
|
201 | 212 |
|
202 | 213 | print("Creating model")
|
203 |
| - model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) |
| 214 | + if not args.weights: |
| 215 | + model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) |
| 216 | + else: |
| 217 | + if PM is None: |
| 218 | + raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") |
| 219 | + model = PM.video.__dict__[args.model](weights=args.weights) |
204 | 220 | model.to(device)
|
205 | 221 | if args.distributed and args.sync_bn:
|
206 | 222 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
@@ -363,6 +379,9 @@ def parse_args():
|
363 | 379 | parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
|
364 | 380 | parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
|
365 | 381 |
|
| 382 | + # Prototype models only |
| 383 | + parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") |
| 384 | + |
366 | 385 | args = parser.parse_args()
|
367 | 386 |
|
368 | 387 | return args
|
|
0 commit comments