diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index b456852f8b9..a9d93da9b0e 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -79,18 +79,30 @@ class Raft_Large_Weights(WeightsEnum): }, ) - # C_T_SKHT_K_V1 = Weights( - # # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: - # # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti - # # Same as CT_SKHT with extra fine-tuning on Kitti - # # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti - # url="", - # transforms=RaftEval, - # meta={ - # "recipe": "", - # "epe": -1234, - # }, - # ) + C_T_SKHT_K_V1 = Weights( + # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", + transforms=RaftEval, + meta={ + **_COMMON_META, + "recipe": "https://github.com/princeton-vl/RAFT", + "kitti_test_f1-all": 5.10, + }, + ) + + C_T_SKHT_K_V2 = Weights( + # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti + # Same as CT_SKHT with extra fine-tuning on Kitti + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", + transforms=RaftEval, + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "kitti_test_f1-all": 5.19, + }, + ) default = C_T_V2