Commit 55ddb93

Merge branch 'main' into prototype/preprocessing_refs
2 parents: a0654dd + 140322f

File tree: 7 files changed, 90 additions and 3 deletions


.circleci/config.yml

Lines changed: 3 additions & 1 deletion
(Generated file; the diff is not rendered by default. The corresponding source template, .circleci/config.yml.in, is shown below.)

.circleci/config.yml.in

Lines changed: 3 additions & 1 deletion
@@ -277,7 +277,9 @@ jobs:
           command: pip install --user --progress-bar=off pytest pytest-mock
       - run:
           name: Run tests
-          command: pytest test/test_prototype_*.py
+          command: pytest --junitxml=test-results/junit.xml -v --durations 20 test/test_prototype_*.py
+      - store_test_results:
+          path: test-results
 
   binary_linux_wheel:
     <<: *binary_common

references/classification/train.py

Lines changed: 5 additions & 0 deletions
@@ -42,6 +42,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args
         else:
             loss = criterion(output, target)
             loss.backward()
+
+            if args.clip_grad_norm is not None:
+                nn.utils.clip_grad_norm_(utils.get_optimizer_params(optimizer), args.clip_grad_norm)
+
             optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:

@@ -472,6 +476,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

references/classification/utils.py

Lines changed: 8 additions & 0 deletions
@@ -409,3 +409,11 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
+
+
+def get_optimizer_params(optimizer):
+    """Generator to iterate over all parameters in the optimizer param_groups."""
+
+    for group in optimizer.param_groups:
+        for p in group["params"]:
+            yield p
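
Note: a quick illustration of why the helper walks optimizer.param_groups rather than model.parameters(). The two-group SGD below is a toy example, not from the reference code; the helper body is copied from the diff so the snippet runs on its own.

import torch
from torch import nn


def get_optimizer_params(optimizer):
    """Generator to iterate over all parameters in the optimizer param_groups."""
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# Two parameter groups, the second with its own learning rate.
optimizer = torch.optim.SGD(
    [
        {"params": model[0].parameters()},
        {"params": model[1].parameters(), "lr": 0.01},
    ],
    lr=0.1,
)

# The generator visits both groups, so it yields the same tensors as model.parameters();
# this is the iterable train.py hands to nn.utils.clip_grad_norm_.
assert len(list(get_optimizer_params(optimizer))) == len(list(model.parameters()))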

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,4 +5,5 @@
 from .imagenet import ImageNet
 from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
 from .sbd import SBD
+from .semeion import SEMEION
 from .voc import VOC
torchvision/prototype/datasets/_builtin/semeion.py (new file)

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import io
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torchdata.datapipes.iter import (
+    IterDataPipe,
+    Mapper,
+    Shuffler,
+    CSVParser,
+)
+from torchvision.prototype.datasets.decoder import raw
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    HttpResource,
+    OnlineResource,
+    DatasetType,
+)
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array
+
+
+class SEMEION(Dataset):
+    def _make_info(self) -> DatasetInfo:
+        return DatasetInfo(
+            "semeion",
+            type=DatasetType.RAW,
+            categories=10,
+            homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
+        )
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        archive = HttpResource(
+            "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
+            sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
+        )
+        return [archive]
+
+    def _collate_and_decode_sample(
+        self,
+        data: Tuple[str, ...],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> Dict[str, Any]:
+        image_data = torch.tensor([float(pixel) for pixel in data[:256]], dtype=torch.uint8).reshape(16, 16)
+        label_data = [int(label) for label in data[256:] if label]
+
+        if decoder is raw:
+            image = image_data.unsqueeze(0)
+        else:
+            image_buffer = image_buffer_from_array(image_data.numpy())
+            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
+
+        label = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
+        category = self.info.categories[label]
+        return dict(image=image, label=label, category=category)
+
+    def _make_datapipe(
+        self,
+        resource_dps: List[IterDataPipe],
+        *,
+        config: DatasetConfig,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> IterDataPipe[Dict[str, Any]]:
+        dp = resource_dps[0]
+        dp = CSVParser(dp, delimiter=" ")
+        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+        dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return dp
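
Note: for reference, a standalone sketch of the row format _collate_and_decode_sample expects. Each space-separated line of semeion.data carries 256 pixel values followed by a length-10 one-hot label; the row below is fabricated (a fake digit 3) purely to illustrate the parsing and does not come from the real file.

import torch

# A fabricated semeion.data row: 256 pixel values, then a one-hot label for digit 3.
fields = ["1.0000"] * 256 + ["0", "0", "0", "1", "0", "0", "0", "0", "0", "0"]

# 256 pixels -> 16x16 uint8 tensor; the RAW decoder path adds a channel dim, giving 1x16x16.
image = torch.tensor([float(pixel) for pixel in fields[:256]], dtype=torch.uint8).reshape(16, 16)
image = image.unsqueeze(0)

# The trailing fields are a one-hot encoding; the label is the index of the 1.
label_data = [int(label) for label in fields[256:] if label]
label = next(idx for idx, one_hot in enumerate(label_data) if one_hot)

print(image.shape, label)  # torch.Size([1, 16, 16]) 3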

torchvision/prototype/datasets/generate_category_files.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def parse_args(argv=None):
 
 
 if __name__ == "__main__":
-    args = parse_args(["-f", "sbd"])
+    args = parse_args()
 
     try:
         main(*args.names, force=args.force)
