Skip to content

Commit ba4cfd6

Browse files
committed
itertools.prodcut
1 parent ff4af14 commit ba4cfd6

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

torchvision/datasets/_optical_flow.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import os
23
import re
34
from abc import ABC, abstractmethod
@@ -320,31 +321,30 @@ def __init__(self, root, split="train", pass_name="clean", camera="left", transf
320321

321322
root = Path(root) / "FlyingThings3D"
322323

323-
for pass_name in passes:
324-
for camera in cameras:
325-
for direction in ["into_future", "into_past"]:
326-
image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
327-
image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs])
328-
329-
flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
330-
flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs])
331-
332-
if not image_dirs or not flow_dirs:
333-
raise FileNotFoundError(
334-
"Could not find the FlyingThings3D flow images. "
335-
"Please make sure the directory structure is correct."
336-
)
337-
338-
for image_dir, flow_dir in zip(image_dirs, flow_dirs):
339-
images = sorted(glob(str(image_dir / "*.png")))
340-
flows = sorted(glob(str(flow_dir / "*.pfm")))
341-
for i in range(len(flows) - 1):
342-
if direction == "into_future":
343-
self._image_list += [[images[i], images[i + 1]]]
344-
self._flow_list += [flows[i]]
345-
elif direction == "into_past":
346-
self._image_list += [[images[i + 1], images[i]]]
347-
self._flow_list += [flows[i + 1]]
324+
directions = ("into_future", "into_past")
325+
for pass_name, camera, direction in itertools.product(passes, cameras, directions):
326+
image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
327+
image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs])
328+
329+
flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
330+
flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs])
331+
332+
if not image_dirs or not flow_dirs:
333+
raise FileNotFoundError(
334+
"Could not find the FlyingThings3D flow images. "
335+
"Please make sure the directory structure is correct."
336+
)
337+
338+
for image_dir, flow_dir in zip(image_dirs, flow_dirs):
339+
images = sorted(glob(str(image_dir / "*.png")))
340+
flows = sorted(glob(str(flow_dir / "*.pfm")))
341+
for i in range(len(flows) - 1):
342+
if direction == "into_future":
343+
self._image_list += [[images[i], images[i + 1]]]
344+
self._flow_list += [flows[i]]
345+
elif direction == "into_past":
346+
self._image_list += [[images[i + 1], images[i]]]
347+
self._flow_list += [flows[i + 1]]
348348

349349
def __getitem__(self, index):
350350
"""Return example at given index.

0 commit comments

Comments
 (0)