Skip to content

Commit a0654dd

Browse files
committed
Adding prototype preprocessing on video references.
1 parent 77c80f5 commit a0654dd

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

references/video_classification/train.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
amp = None
1919

2020

21+
try:
22+
from torchvision.prototype import models as PM
23+
except ImportError:
24+
PM = None
25+
26+
2127
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
2228
model.train()
2329
metric_logger = utils.MetricLogger(delimiter=" ")
@@ -149,7 +155,12 @@ def main(args):
149155
print("Loading validation data")
150156
cache_path = _get_cache_path(valdir)
151157

152-
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
158+
if not args.weights:
159+
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
160+
else:
161+
fn = PM.video.__dict__[args.model]
162+
weights = PM._api.get_weight(fn, args.weights)
163+
transform_test = weights.transforms()
153164

154165
if args.cache_dataset and os.path.exists(cache_path):
155166
print(f"Loading dataset_test from {cache_path}")
@@ -200,7 +211,12 @@ def main(args):
200211
)
201212

202213
print("Creating model")
203-
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
214+
if not args.weights:
215+
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
216+
else:
217+
if PM is None:
218+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
219+
model = PM.video.__dict__[args.model](weights=args.weights)
204220
model.to(device)
205221
if args.distributed and args.sync_bn:
206222
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -363,6 +379,9 @@ def parse_args():
363379
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
364380
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
365381

382+
# Prototype models only
383+
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
384+
366385
args = parser.parse_args()
367386

368387
return args

0 commit comments

Comments
 (0)