Commit f95b053

Updated video classification ref example with new transforms (#2935)
* [WIP] Update ref example video classification
* [WIP] Updated video classification ref example
* Replaced mem format conversion functions by classes
1 parent 044fcf2 commit f95b053

File tree: 3 files changed, +61 −124 lines
references/video_classification/README.md

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+# Video Classification
+
+TODO: Add some info about the context, dataset we use etc.
+
+## Data preparation
+
+If you have already downloaded the [Kinetics400 dataset](https://deepmind.com/research/open-source/kinetics),
+please proceed directly to the next section.
+
+To download videos, one can use https://github.com/Showmax/kinetics-downloader
+
+## Training
+
+We assume the training and validation AVI videos are stored at `/data/kinectics400/train` and
+`/data/kinectics400/val`.
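
For orientation: torchvision's Kinetics400 loader discovers videos through one sub-folder per class under each split, so a layout along these lines is assumed (the class and file names below are purely illustrative, not part of this commit):

```
/data/kinectics400/
├── train/
│   ├── abseiling/
│   │   ├── <video_id>.avi
│   │   └── ...
│   └── .../
└── val/
    ├── abseiling/
    │   └── ...
    └── .../
```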
+
+### Multiple GPUs
+
+Run the training on a single node with 8 GPUs:
+```bash
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --data-path=/data/kinectics400 --train-dir=train --val-dir=val --batch-size=16 --cache-dataset --sync-bn --apex
+```
+
+### Single GPU
+
+**Note:** training on a single GPU can be extremely slow.
+
+```bash
+python train.py --data-path=/data/kinectics400 --train-dir=train --val-dir=val --batch-size=8 --cache-dataset
+```

references/video_classification/train.py

Lines changed: 15 additions & 9 deletions
@@ -7,13 +7,13 @@
 from torch import nn
 import torchvision
 import torchvision.datasets.video_utils
-from torchvision import transforms
+from torchvision import transforms as T
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
 import utils
 
 from scheduler import WarmupMultiStepLR
-import transforms as T
+from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
 
 try:
     from apex import amp
@@ -119,11 +119,13 @@ def main(args):
     st = time.time()
     cache_path = _get_cache_path(traindir)
     transform_train = torchvision.transforms.Compose([
-        T.ToFloatTensorInZeroOne(),
+        ConvertBHWCtoBCHW(),
+        T.ConvertImageDtype(torch.float32),
         T.Resize((128, 171)),
         T.RandomHorizontalFlip(),
         normalize,
-        T.RandomCrop((112, 112))
+        T.RandomCrop((112, 112)),
+        ConvertBCHWtoCBHW()
     ])
 
     if args.cache_dataset and os.path.exists(cache_path):
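
For context, the reworked pipeline above keeps the clip in (T, C, H, W) layout while the stock per-frame transforms run, and only converts to the channels-first (C, T, H, W) layout expected by 3D-conv models at the very end. Below is a minimal sketch, not part of the commit, of how the new composition behaves on a dummy clip; the `normalize` statistics shown are an assumption, since their definition sits outside this hunk:

```python
import torch
from torchvision import transforms as T

from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW

# Dummy clip shaped like a decoded video: (T, H, W, C), uint8 in [0, 255].
clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)

# Assumed Kinetics mean/std; train.py defines `normalize` outside this hunk.
normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
                        std=[0.22803, 0.22145, 0.216989])

transform_train = T.Compose([
    ConvertBHWCtoBCHW(),                 # (T, H, W, C) -> (T, C, H, W)
    T.ConvertImageDtype(torch.float32),  # uint8 [0, 255] -> float32 [0, 1]
    T.Resize((128, 171)),                # per-frame spatial resize
    T.RandomHorizontalFlip(),
    normalize,
    T.RandomCrop((112, 112)),
    ConvertBCHWtoCBHW(),                 # (T, C, H, W) -> (C, T, H, W)
])

out = transform_train(clip)
print(out.shape)  # torch.Size([3, 16, 112, 112])
```

The same shape flow applies to `transform_test` further down, minus the random flip and with a center crop instead.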
@@ -139,7 +141,8 @@ def main(args):
             frames_per_clip=args.clip_len,
             step_between_clips=1,
             transform=transform_train,
-            frame_rate=15
+            frame_rate=15,
+            extensions=('avi', 'mp4', )
         )
         if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
@@ -152,10 +155,12 @@ def main(args):
     cache_path = _get_cache_path(valdir)
 
     transform_test = torchvision.transforms.Compose([
-        T.ToFloatTensorInZeroOne(),
+        ConvertBHWCtoBCHW(),
+        T.ConvertImageDtype(torch.float32),
         T.Resize((128, 171)),
         normalize,
-        T.CenterCrop((112, 112))
+        T.CenterCrop((112, 112)),
+        ConvertBCHWtoCBHW()
     ])
 
     if args.cache_dataset and os.path.exists(cache_path):
@@ -171,7 +176,8 @@ def main(args):
             frames_per_clip=args.clip_len,
             step_between_clips=1,
             transform=transform_test,
-            frame_rate=15
+            frame_rate=15,
+            extensions=('avi', 'mp4',)
         )
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
@@ -265,7 +271,7 @@ def main(args):
 
 def parse_args():
     import argparse
-    parser = argparse.ArgumentParser(description='PyTorch Classification Training')
+    parser = argparse.ArgumentParser(description='PyTorch Video Classification Training')
 
     parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset')
     parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir')
references/video_classification/transforms.py

Lines changed: 11 additions & 115 deletions
@@ -1,122 +1,18 @@
 import torch
-import random
+import torch.nn as nn
 
 
-def crop(vid, i, j, h, w):
-    return vid[..., i:(i + h), j:(j + w)]
+class ConvertBHWCtoBCHW(nn.Module):
+    """Convert tensor from (B, H, W, C) to (B, C, H, W)
+    """
 
+    def forward(self, vid: torch.Tensor) -> torch.Tensor:
+        return vid.permute(0, 3, 1, 2)
 
-def center_crop(vid, output_size):
-    h, w = vid.shape[-2:]
-    th, tw = output_size
 
-    i = int(round((h - th) / 2.))
-    j = int(round((w - tw) / 2.))
-    return crop(vid, i, j, th, tw)
+class ConvertBCHWtoCBHW(nn.Module):
+    """Convert tensor from (B, C, H, W) to (C, B, H, W)
+    """
 
-
-def hflip(vid):
-    return vid.flip(dims=(-1,))
-
-
-# NOTE: for those functions, which generally expect mini-batches, we keep them
-# as non-minibatch so that they are applied as if they were 4d (thus image).
-# this way, we only apply the transformation in the spatial domain
-def resize(vid, size, interpolation='bilinear'):
-    # NOTE: using bilinear interpolation because we don't work on minibatches
-    # at this level
-    scale = None
-    if isinstance(size, int):
-        scale = float(size) / min(vid.shape[-2:])
-        size = None
-    return torch.nn.functional.interpolate(
-        vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False)
-
-
-def pad(vid, padding, fill=0, padding_mode="constant"):
-    # NOTE: don't want to pad on temporal dimension, so let as non-batch
-    # (4d) before padding. This works as expected
-    return torch.nn.functional.pad(vid, padding, value=fill, mode=padding_mode)
-
-
-def to_normalized_float_tensor(vid):
-    return vid.permute(3, 0, 1, 2).to(torch.float32) / 255
-
-
-def normalize(vid, mean, std):
-    shape = (-1,) + (1,) * (vid.dim() - 1)
-    mean = torch.as_tensor(mean).reshape(shape)
-    std = torch.as_tensor(std).reshape(shape)
-    return (vid - mean) / std
-
-
-# Class interface
-
-class RandomCrop(object):
-    def __init__(self, size):
-        self.size = size
-
-    @staticmethod
-    def get_params(vid, output_size):
-        """Get parameters for ``crop`` for a random crop.
-        """
-        h, w = vid.shape[-2:]
-        th, tw = output_size
-        if w == tw and h == th:
-            return 0, 0, h, w
-        i = random.randint(0, h - th)
-        j = random.randint(0, w - tw)
-        return i, j, th, tw
-
-    def __call__(self, vid):
-        i, j, h, w = self.get_params(vid, self.size)
-        return crop(vid, i, j, h, w)
-
-
-class CenterCrop(object):
-    def __init__(self, size):
-        self.size = size
-
-    def __call__(self, vid):
-        return center_crop(vid, self.size)
-
-
-class Resize(object):
-    def __init__(self, size):
-        self.size = size
-
-    def __call__(self, vid):
-        return resize(vid, self.size)
-
-
-class ToFloatTensorInZeroOne(object):
-    def __call__(self, vid):
-        return to_normalized_float_tensor(vid)
-
-
-class Normalize(object):
-    def __init__(self, mean, std):
-        self.mean = mean
-        self.std = std
-
-    def __call__(self, vid):
-        return normalize(vid, self.mean, self.std)
-
-
-class RandomHorizontalFlip(object):
-    def __init__(self, p=0.5):
-        self.p = p
-
-    def __call__(self, vid):
-        if random.random() < self.p:
-            return hflip(vid)
-        return vid
-
-
-class Pad(object):
-    def __init__(self, padding, fill=0):
-        self.padding = padding
-        self.fill = fill
-
-    def __call__(self, vid):
-        return pad(vid, self.padding, self.fill)
+    def forward(self, vid: torch.Tensor) -> torch.Tensor:
+        return vid.permute(1, 0, 2, 3)
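
As a quick sanity check that the refactor preserves numerics, the two new modules plus an explicit dtype conversion reproduce what the removed `ToFloatTensorInZeroOne` computed in one step. A minimal sketch (run from `references/video_classification/` so the local `transforms` module resolves):

```python
import torch

from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW

vid = torch.randint(0, 256, (8, 4, 4, 3), dtype=torch.uint8)  # (T, H, W, C)

# Removed path: permute straight to (C, T, H, W) and scale to [0, 1].
old = vid.permute(3, 0, 1, 2).to(torch.float32) / 255

# New path: (T, H, W, C) -> (T, C, H, W), scale, then (C, T, H, W) last.
new = ConvertBCHWtoCBHW()(ConvertBHWCtoBCHW()(vid).to(torch.float32) / 255)

assert torch.equal(old, new)
```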
