From 347c61214c610fb77a8fc61ce2b33cf5f8b1da98 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 9 Dec 2021 11:33:24 +0000 Subject: [PATCH 1/5] WIP --- references/optical_flow/train.py | 3 +- .../prototype/models/optical_flow/raft.py | 34 +++++++++++++------ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 326f0be5f66..254d01ed226 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -210,7 +210,8 @@ def main(args): if args.resume is not None: d = torch.load(args.resume, map_location="cpu") - model.load_state_dict(d, strict=True) + # model.load_state_dict(d, strict=True) + model.module.load_state_dict(d, strict=True) if args.train_dataset is None: # Set deterministic CUDNN algorithms, since they can affect epe a fair bit. diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index b1b5fcbe911..e0c79da6a21 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -49,17 +49,29 @@ class Raft_Large_Weights(WeightsEnum): }, ) - # C_T_SKHT_V1 = Weights( - # # Chairs + Things + Sintel fine-tuning, i.e.: - # # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) - # # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel - # url="", - # transforms=RaftEval, - # meta={ - # "recipe": "", - # "epe": -1234, - # }, - # ) + C_T_SKHT_V1 = Weights( + # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", + transforms=RaftEval, + meta={ + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_test_cleanpass_epe": 1.94, + "sintel_test_finalpass_epe": 3.18, + }, + ) + + C_T_SKHT_V2 = Weights( + # Chairs + Things + Sintel fine-tuning, i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", + transforms=RaftEval, + meta={ + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_test_cleanpass_epe": 1.819, + "sintel_test_finalpass_epe": 3.067, + }, + ) # C_T_SKHT_K_V1 = Weights( # # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: From acffc50ef1a69a13639f3837379463688af53f01 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 9 Dec 2021 11:36:57 +0000 Subject: [PATCH 2/5] WIP --- torchvision/prototype/models/optical_flow/raft.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 389259061a5..ceb8283ac39 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -58,7 +58,8 @@ class Raft_Large_Weights(WeightsEnum): url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", transforms=RaftEval, meta={ - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + **_COMMON_META, + "recipe": "https://github.com/princeton-vl/RAFT", "sintel_test_cleanpass_epe": 1.94, "sintel_test_finalpass_epe": 3.18, }, @@ -71,6 +72,7 @@ class Raft_Large_Weights(WeightsEnum): url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", transforms=RaftEval, meta={ + **_COMMON_META, "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", "sintel_test_cleanpass_epe": 1.819, "sintel_test_finalpass_epe": 3.067, From d3e3f1ffa9786a0bf06d5d88a0a619093c491e8f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 9 Dec 2021 11:37:57 +0000 Subject: [PATCH 3/5] WIP --- references/optical_flow/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 254d01ed226..326f0be5f66 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -210,8 +210,7 @@ def main(args): if args.resume is not None: d = torch.load(args.resume, map_location="cpu") - # model.load_state_dict(d, strict=True) - model.module.load_state_dict(d, strict=True) + model.load_state_dict(d, strict=True) if args.train_dataset is None: # Set deterministic CUDNN algorithms, since they can affect epe a fair bit. From 5fea06426b74fe714a0ca7c31e8b76890dda4c71 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 9 Dec 2021 11:45:03 +0000 Subject: [PATCH 4/5] Fix download link for raft_small --- torchvision/prototype/models/optical_flow/raft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index ca4ae90927e..1b8343badeb 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -97,11 +97,11 @@ class Raft_Small_Weights(WeightsEnum): ) C_T_V2 = Weights( # Chairs + Things - url="https://github.com/pytorch/vision/tree/main/references/optical_flow", + url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", transforms=RaftEval, meta={ **_COMMON_META, - "recipe": "https://github.com/princeton-vl/RAFT", + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", "sintel_train_cleanpass_epe": 1.9901, "sintel_train_finalpass_epe": 3.2831, "kitti_train_per_image_epe": 7.5978, From bfba31fadf7df7e6d9eef9eb46038c1e53b49495 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 9 Dec 2021 11:48:21 +0000 Subject: [PATCH 5/5] add new line in readme --- references/optical_flow/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/optical_flow/README.md b/references/optical_flow/README.md index f722b70ae41..9b08553708a 100644 --- a/references/optical_flow/README.md +++ b/references/optical_flow/README.md @@ -62,4 +62,4 @@ You can also evaluate on Kitti train: ``` torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679 -``` \ No newline at end of file +```