|
| 1 | +import itertools |
1 | 2 | import os
|
2 | 3 | import re
|
3 | 4 | from abc import ABC, abstractmethod
|
@@ -320,31 +321,30 @@ def __init__(self, root, split="train", pass_name="clean", camera="left", transf
|
320 | 321 |
|
321 | 322 | root = Path(root) / "FlyingThings3D"
|
322 | 323 |
|
323 |
| - for pass_name in passes: |
324 |
| - for camera in cameras: |
325 |
| - for direction in ["into_future", "into_past"]: |
326 |
| - image_dirs = sorted(glob(str(root / pass_name / split / "*/*"))) |
327 |
| - image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs]) |
328 |
| - |
329 |
| - flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*"))) |
330 |
| - flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs]) |
331 |
| - |
332 |
| - if not image_dirs or not flow_dirs: |
333 |
| - raise FileNotFoundError( |
334 |
| - "Could not find the FlyingThings3D flow images. " |
335 |
| - "Please make sure the directory structure is correct." |
336 |
| - ) |
337 |
| - |
338 |
| - for image_dir, flow_dir in zip(image_dirs, flow_dirs): |
339 |
| - images = sorted(glob(str(image_dir / "*.png"))) |
340 |
| - flows = sorted(glob(str(flow_dir / "*.pfm"))) |
341 |
| - for i in range(len(flows) - 1): |
342 |
| - if direction == "into_future": |
343 |
| - self._image_list += [[images[i], images[i + 1]]] |
344 |
| - self._flow_list += [flows[i]] |
345 |
| - elif direction == "into_past": |
346 |
| - self._image_list += [[images[i + 1], images[i]]] |
347 |
| - self._flow_list += [flows[i + 1]] |
| 324 | + directions = ("into_future", "into_past") |
| 325 | + for pass_name, camera, direction in itertools.product(passes, cameras, directions): |
| 326 | + image_dirs = sorted(glob(str(root / pass_name / split / "*/*"))) |
| 327 | + image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs]) |
| 328 | + |
| 329 | + flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*"))) |
| 330 | + flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs]) |
| 331 | + |
| 332 | + if not image_dirs or not flow_dirs: |
| 333 | + raise FileNotFoundError( |
| 334 | + "Could not find the FlyingThings3D flow images. " |
| 335 | + "Please make sure the directory structure is correct." |
| 336 | + ) |
| 337 | + |
| 338 | + for image_dir, flow_dir in zip(image_dirs, flow_dirs): |
| 339 | + images = sorted(glob(str(image_dir / "*.png"))) |
| 340 | + flows = sorted(glob(str(flow_dir / "*.pfm"))) |
| 341 | + for i in range(len(flows) - 1): |
| 342 | + if direction == "into_future": |
| 343 | + self._image_list += [[images[i], images[i + 1]]] |
| 344 | + self._flow_list += [flows[i]] |
| 345 | + elif direction == "into_past": |
| 346 | + self._image_list += [[images[i + 1], images[i]]] |
| 347 | + self._flow_list += [flows[i + 1]] |
348 | 348 |
|
349 | 349 | def __getitem__(self, index):
|
350 | 350 | """Return example at given index.
|
|
0 commit comments