From 7184dfd72012231e0bbe8c5b49629dcbc6f1d50b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 20:49:37 +0000 Subject: [PATCH 1/5] Renaming C++ files & methods according to recommended naming conventions and aligning them with Python's API. Syncing, where possible, the names of functions across devices. --- ...SROIAlign_cpu.cpp => ps_roi_align_cpu.cpp} | 20 +++++++++---------- torchvision/csrc/cpu/vision_cpu.h | 4 ++-- ...SROIAlign_cuda.cu => ps_roi_align_cuda.cu} | 20 +++++++++---------- torchvision/csrc/cuda/vision_cuda.h | 4 ++-- .../csrc/{PSROIAlign.h => ps_roi_align.h} | 6 +++--- torchvision/csrc/vision.cpp | 16 +++++++-------- 6 files changed, 35 insertions(+), 35 deletions(-) rename torchvision/csrc/cpu/{PSROIAlign_cpu.cpp => ps_roi_align_cpu.cpp} (95%) rename torchvision/csrc/cuda/{PSROIAlign_cuda.cu => ps_roi_align_cuda.cu} (95%) rename torchvision/csrc/{PSROIAlign.h => ps_roi_align.h} (97%) diff --git a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp b/torchvision/csrc/cpu/ps_roi_align_cpu.cpp similarity index 95% rename from torchvision/csrc/cpu/PSROIAlign_cpu.cpp rename to torchvision/csrc/cpu/ps_roi_align_cpu.cpp index 899dbb208b6..b25115701a2 100644 --- a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/ps_roi_align_cpu.cpp @@ -57,7 +57,7 @@ T bilinear_interpolate( } template -void PSROIAlignForwardCPU( +void ps_roi_align_forward_kernel_impl( int nthreads, const T* input, const T spatial_scale, @@ -202,7 +202,7 @@ inline void add(T* address, const T& val) { } template -void PSROIAlignBackwardCPU( +void ps_roi_align_backward_kernel_impl( int nthreads, const T* grad_output, const int* channel_mapping, @@ -298,7 +298,7 @@ void PSROIAlignBackwardCPU( } } -std::tuple PSROIAlign_forward_cpu( +std::tuple ps_roi_align_forward_cpu( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -313,7 +313,7 @@ std::tuple PSROIAlign_forward_cpu( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "PSROIAlign_forward_cpu"; + at::CheckedFrom c = "ps_roi_align_forward_cpu"; at::checkAllSameType(c, {input_t, rois_t}); int num_rois = rois.size(0); @@ -338,8 +338,8 @@ std::tuple PSROIAlign_forward_cpu( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "PSROIAlign_forward", [&] { - PSROIAlignForwardCPU( + input.scalar_type(), "ps_roi_align_forward", [&] { + ps_roi_align_forward_kernel_impl( output_size, input_.data_ptr(), spatial_scale, @@ -357,7 +357,7 @@ std::tuple PSROIAlign_forward_cpu( return std::make_tuple(output, channel_mapping); } -at::Tensor PSROIAlign_backward_cpu( +at::Tensor ps_roi_align_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, @@ -379,7 +379,7 @@ at::Tensor PSROIAlign_backward_cpu( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; - at::CheckedFrom c = "PSROIAlign_backward_cpu"; + at::CheckedFrom c = "ps_roi_align_backward_cpu"; at::checkAllSameType(c, {grad_t, rois_t}); auto num_rois = rois.size(0); @@ -395,8 +395,8 @@ at::Tensor PSROIAlign_backward_cpu( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "PSROIAlign_backward", [&] { - PSROIAlignBackwardCPU( + grad.scalar_type(), "ps_roi_align_backward", [&] { + ps_roi_align_backward_kernel_impl( grad.numel(), grad_.data_ptr(), channel_mapping.data_ptr(), diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 39d89bf6515..db15c7172a4 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -4,7 +4,7 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple PSROIAlign_forward_cpu( +VISION_API std::tuple ps_roi_align_forward_cpu( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -12,7 +12,7 @@ VISION_API std::tuple PSROIAlign_forward_cpu( int64_t pooled_width, int64_t sampling_ratio); -VISION_API at::Tensor PSROIAlign_backward_cpu( +VISION_API at::Tensor ps_roi_align_backward_cpu( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, diff --git a/torchvision/csrc/cuda/PSROIAlign_cuda.cu b/torchvision/csrc/cuda/ps_roi_align_cuda.cu similarity index 95% rename from torchvision/csrc/cuda/PSROIAlign_cuda.cu rename to torchvision/csrc/cuda/ps_roi_align_cuda.cu index e6912d8c7ee..cfb4915bd76 100644 --- a/torchvision/csrc/cuda/PSROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ps_roi_align_cuda.cu @@ -62,7 +62,7 @@ __device__ T bilinear_interpolate( } template -__global__ void PSROIAlignForwardCUDA( +__global__ void ps_roi_align_forward_kernel_impl( int nthreads, const T* input, const T spatial_scale, @@ -195,7 +195,7 @@ __device__ void bilinear_interpolate_gradient( } template -__global__ void PSROIAlignBackwardCUDA( +__global__ void ps_roi_align_backward_kernel_impl( int nthreads, const T* grad_output, const int* channel_mapping, @@ -292,7 +292,7 @@ __global__ void PSROIAlignBackwardCUDA( } } -std::tuple PSROIAlign_forward_cuda( +std::tuple ps_roi_align_forward_cuda( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -307,7 +307,7 @@ std::tuple PSROIAlign_forward_cuda( at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; - at::CheckedFrom c = "PSROIAlign_forward_cuda"; + at::CheckedFrom c = "ps_roi_align_forward_cuda"; at::checkAllSameGPU(c, {input_t, rois_t}); at::checkAllSameType(c, {input_t, rois_t}); @@ -344,8 +344,8 @@ std::tuple PSROIAlign_forward_cuda( auto input_ = input.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - input.scalar_type(), "PSROIAlign_forward", [&] { - PSROIAlignForwardCUDA<<>>( + input.scalar_type(), "ps_roi_align_forward", [&] { + ps_roi_align_forward_kernel_impl<<>>( output_size, input_.data_ptr(), spatial_scale, @@ -365,7 +365,7 @@ std::tuple PSROIAlign_forward_cuda( return std::make_tuple(output, channel_mapping); } -at::Tensor PSROIAlign_backward_cuda( +at::Tensor ps_roi_align_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, @@ -387,7 +387,7 @@ at::Tensor PSROIAlign_backward_cuda( at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, channel_mapping_t{channel_mapping, "channel_mapping", 3}; - at::CheckedFrom c = "PSROIAlign_backward_cuda"; + at::CheckedFrom c = "ps_roi_align_backward_cuda"; at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t}); at::checkAllSameType(c, {grad_t, rois_t}); @@ -415,8 +415,8 @@ at::Tensor PSROIAlign_backward_cuda( auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad.scalar_type(), "PSROIAlign_backward", [&] { - PSROIAlignBackwardCUDA<<>>( + grad.scalar_type(), "ps_roi_align_backward", [&] { + ps_roi_align_backward_kernel_impl<<>>( grad.numel(), grad_.data_ptr(), channel_mapping.data_ptr(), diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index b17f00d6acf..b4a37f19d0b 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -4,7 +4,7 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple PSROIAlign_forward_cuda( +VISION_API std::tuple ps_roi_align_forward_cuda( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -12,7 +12,7 @@ VISION_API std::tuple PSROIAlign_forward_cuda( int64_t pooled_width, int64_t sampling_ratio); -VISION_API at::Tensor PSROIAlign_backward_cuda( +VISION_API at::Tensor ps_roi_align_backward_cuda( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/ps_roi_align.h similarity index 97% rename from torchvision/csrc/PSROIAlign.h rename to torchvision/csrc/ps_roi_align.h index 1e5dd17aabc..b8f1c4b7300 100644 --- a/torchvision/csrc/PSROIAlign.h +++ b/torchvision/csrc/ps_roi_align.h @@ -30,7 +30,7 @@ std::tuple ps_roi_align( } #if defined(WITH_CUDA) || defined(WITH_HIP) -std::tuple PSROIAlign_autocast( +std::tuple ps_roi_align_autocast( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -186,7 +186,7 @@ class PSROIAlignBackwardFunction } }; -std::tuple PSROIAlign_autograd( +std::tuple ps_roi_align_autograd( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, @@ -199,7 +199,7 @@ std::tuple PSROIAlign_autograd( return std::make_tuple(result[0], result[1]); } -at::Tensor PSROIAlign_backward_autograd( +at::Tensor ps_roi_align_backward_autograd( const at::Tensor& grad, const at::Tensor& rois, const at::Tensor& channel_mapping, diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index 2d4e2af0f53..c5c204aac2b 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -8,13 +8,13 @@ #include #endif -#include "PSROIAlign.h" #include "PSROIPool.h" #include "ROIAlign.h" #include "ROIPool.h" #include "deform_conv2d.h" #include "empty_tensor_op.h" #include "nms.h" +#include "ps_roi_align.h" // If we are in a Windows environment, we need to define // initialization functions for the _custom_ops extension @@ -65,8 +65,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl("deform_conv2d", deform_conv2d_forward_cpu); m.impl("_deform_conv2d_backward", deform_conv2d_backward_cpu); m.impl("nms", nms_cpu); - m.impl("ps_roi_align", PSROIAlign_forward_cpu); - m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu); + m.impl("ps_roi_align", ps_roi_align_forward_cpu); + m.impl("_ps_roi_align_backward", ps_roi_align_backward_cpu); m.impl("ps_roi_pool", PSROIPool_forward_cpu); m.impl("_ps_roi_pool_backward", PSROIPool_backward_cpu); m.impl("roi_align", ROIAlign_forward_cpu); @@ -81,8 +81,8 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { m.impl("deform_conv2d", deform_conv2d_forward_cuda); m.impl("_deform_conv2d_backward", deform_conv2d_backward_cuda); m.impl("nms", nms_cuda); - m.impl("ps_roi_align", PSROIAlign_forward_cuda); - m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda); + m.impl("ps_roi_align", ps_roi_align_forward_cuda); + m.impl("_ps_roi_align_backward", ps_roi_align_backward_cuda); m.impl("ps_roi_pool", PSROIPool_forward_cuda); m.impl("_ps_roi_pool_backward", PSROIPool_backward_cuda); m.impl("roi_align", ROIAlign_forward_cuda); @@ -97,7 +97,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { m.impl("deform_conv2d", deform_conv2d_autocast); m.impl("nms", nms_autocast); - m.impl("ps_roi_align", PSROIAlign_autocast); + m.impl("ps_roi_align", ps_roi_align_autocast); m.impl("ps_roi_pool", PSROIPool_autocast); m.impl("roi_align", ROIAlign_autocast); m.impl("roi_pool", ROIPool_autocast); @@ -107,8 +107,8 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { m.impl("deform_conv2d", deform_conv2d_autograd); m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd); - m.impl("ps_roi_align", PSROIAlign_autograd); - m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd); + m.impl("ps_roi_align", ps_roi_align_autograd); + m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd); m.impl("ps_roi_pool", PSROIPool_autograd); m.impl("_ps_roi_pool_backward", PSROIPool_backward_autograd); m.impl("roi_align", ROIAlign_autograd); From b41a93c5b957a3afc9049ce1ec83e854bd06131b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 21:01:20 +0000 Subject: [PATCH 2/5] Adding all internal functions in anonymous namespaces. --- torchvision/csrc/cpu/ps_roi_align_cpu.cpp | 4 ++++ torchvision/csrc/cuda/ps_roi_align_cuda.cu | 4 ++++ torchvision/csrc/ps_roi_align.h | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/torchvision/csrc/cpu/ps_roi_align_cpu.cpp b/torchvision/csrc/cpu/ps_roi_align_cpu.cpp index b25115701a2..2bfe6bfe132 100644 --- a/torchvision/csrc/cpu/ps_roi_align_cpu.cpp +++ b/torchvision/csrc/cpu/ps_roi_align_cpu.cpp @@ -2,6 +2,8 @@ #include #include +namespace { + template T bilinear_interpolate( const T* input, @@ -298,6 +300,8 @@ void ps_roi_align_backward_kernel_impl( } } +} // namespace + std::tuple ps_roi_align_forward_cpu( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/cuda/ps_roi_align_cuda.cu b/torchvision/csrc/cuda/ps_roi_align_cuda.cu index cfb4915bd76..64a15deacaa 100644 --- a/torchvision/csrc/cuda/ps_roi_align_cuda.cu +++ b/torchvision/csrc/cuda/ps_roi_align_cuda.cu @@ -7,6 +7,8 @@ #include "cuda_helpers.h" +namespace { + template __device__ T bilinear_interpolate( const T* input, @@ -292,6 +294,8 @@ __global__ void ps_roi_align_backward_kernel_impl( } } +} // namespace + std::tuple ps_roi_align_forward_cuda( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/ps_roi_align.h b/torchvision/csrc/ps_roi_align.h index b8f1c4b7300..12a322b9603 100644 --- a/torchvision/csrc/ps_roi_align.h +++ b/torchvision/csrc/ps_roi_align.h @@ -82,6 +82,8 @@ at::Tensor _ps_roi_align_backward( width); } +namespace { + class PSROIAlignFunction : public torch::autograd::Function { public: @@ -186,6 +188,8 @@ class PSROIAlignBackwardFunction } }; +} // namespace + std::tuple ps_roi_align_autograd( const at::Tensor& input, const at::Tensor& rois, From 84950abff34cd87a7349e847ae2d0de40f382dd1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 21:05:05 +0000 Subject: [PATCH 3/5] Renaming C++/CUDA kernel files and moving operator code from header to cpp file. --- .../csrc/cpu/{ps_roi_align_cpu.cpp => ps_roi_align_kernel.cpp} | 0 .../csrc/cuda/{ps_roi_align_cuda.cu => ps_roi_align_kernel.cu} | 0 torchvision/csrc/{ps_roi_align.h => ps_roi_align.cpp} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename torchvision/csrc/cpu/{ps_roi_align_cpu.cpp => ps_roi_align_kernel.cpp} (100%) rename torchvision/csrc/cuda/{ps_roi_align_cuda.cu => ps_roi_align_kernel.cu} (100%) rename torchvision/csrc/{ps_roi_align.h => ps_roi_align.cpp} (100%) diff --git a/torchvision/csrc/cpu/ps_roi_align_cpu.cpp b/torchvision/csrc/cpu/ps_roi_align_kernel.cpp similarity index 100% rename from torchvision/csrc/cpu/ps_roi_align_cpu.cpp rename to torchvision/csrc/cpu/ps_roi_align_kernel.cpp diff --git a/torchvision/csrc/cuda/ps_roi_align_cuda.cu b/torchvision/csrc/cuda/ps_roi_align_kernel.cu similarity index 100% rename from torchvision/csrc/cuda/ps_roi_align_cuda.cu rename to torchvision/csrc/cuda/ps_roi_align_kernel.cu diff --git a/torchvision/csrc/ps_roi_align.h b/torchvision/csrc/ps_roi_align.cpp similarity index 100% rename from torchvision/csrc/ps_roi_align.h rename to torchvision/csrc/ps_roi_align.cpp From 9f7d7888662366ed6e9f9d3df72848a981938abf Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 21:15:36 +0000 Subject: [PATCH 4/5] Create foreach cpp file a separate header file with "public" functions. --- torchvision/csrc/cpu/ps_roi_align_kernel.h | 25 ++++++++ torchvision/csrc/cpu/vision_cpu.h | 21 ------- torchvision/csrc/cuda/ps_roi_align_kernel.h | 25 ++++++++ torchvision/csrc/cuda/vision_cuda.h | 21 ------- torchvision/csrc/ps_roi_align.cpp | 18 ++---- torchvision/csrc/ps_roi_align.h | 66 +++++++++++++++++++++ 6 files changed, 120 insertions(+), 56 deletions(-) create mode 100644 torchvision/csrc/cpu/ps_roi_align_kernel.h create mode 100644 torchvision/csrc/cuda/ps_roi_align_kernel.h create mode 100644 torchvision/csrc/ps_roi_align.h diff --git a/torchvision/csrc/cpu/ps_roi_align_kernel.h b/torchvision/csrc/cpu/ps_roi_align_kernel.h new file mode 100644 index 00000000000..86a3f9a8876 --- /dev/null +++ b/torchvision/csrc/cpu/ps_roi_align_kernel.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API std::tuple ps_roi_align_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); + +VISION_API at::Tensor ps_roi_align_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index db15c7172a4..22119b5e292 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -4,27 +4,6 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple ps_roi_align_forward_cpu( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio); - -VISION_API at::Tensor ps_roi_align_backward_cpu( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - VISION_API std::tuple PSROIPool_forward_cpu( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/cuda/ps_roi_align_kernel.h b/torchvision/csrc/cuda/ps_roi_align_kernel.h new file mode 100644 index 00000000000..45a300d6711 --- /dev/null +++ b/torchvision/csrc/cuda/ps_roi_align_kernel.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API std::tuple ps_roi_align_forward_cuda( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); + +VISION_API at::Tensor ps_roi_align_backward_cuda( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index b4a37f19d0b..c80386a8db1 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -4,27 +4,6 @@ // TODO: Delete this file once all the methods are gone -VISION_API std::tuple ps_roi_align_forward_cuda( - const at::Tensor& input, - const at::Tensor& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio); - -VISION_API at::Tensor ps_roi_align_backward_cuda( - const at::Tensor& grad, - const at::Tensor& rois, - const at::Tensor& channel_mapping, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width); - VISION_API std::tuple PSROIPool_forward_cuda( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/ps_roi_align.cpp b/torchvision/csrc/ps_roi_align.cpp index 12a322b9603..0e1a30d6e63 100644 --- a/torchvision/csrc/ps_roi_align.cpp +++ b/torchvision/csrc/ps_roi_align.cpp @@ -1,20 +1,10 @@ -#pragma once +#include "ps_roi_align.h" +#include -#include "cpu/vision_cpu.h" - -#ifdef WITH_CUDA -#include "autocast.h" -#include "cuda/vision_cuda.h" -#endif -#ifdef WITH_HIP -#include "autocast.h" -#include "hip/vision_cuda.h" +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include #endif -#include - -// TODO: put this stuff in torchvision namespace - std::tuple ps_roi_align( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/ps_roi_align.h b/torchvision/csrc/ps_roi_align.h new file mode 100644 index 00000000000..0f7ecea2f12 --- /dev/null +++ b/torchvision/csrc/ps_roi_align.h @@ -0,0 +1,66 @@ +#pragma once + +#include "cpu/ps_roi_align_kernel.h" + +#ifdef WITH_CUDA +#include "cuda/ps_roi_align_kernel.h" +#endif +#ifdef WITH_HIP +#include "hip/ps_roi_align_kernel.h" +#endif + +// C++ Forward +std::tuple ps_roi_align( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); + +// Autocast Forward +#if defined(WITH_CUDA) || defined(WITH_HIP) +std::tuple ps_roi_align_autocast( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); +#endif + +// C++ Backward +at::Tensor _ps_roi_align_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); + +// Autograd Forward and Backward +std::tuple ps_roi_align_autograd( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio); + +at::Tensor ps_roi_align_backward_autograd( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + int64_t batch_size, + int64_t channels, + int64_t height, + int64_t width); From 864e0319b2ab14267cc3263eba804bf1a15c4cfd Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 1 Dec 2020 21:21:44 +0000 Subject: [PATCH 5/5] Removing unnecessary repeated includes. --- torchvision/csrc/cpu/ps_roi_align_kernel.cpp | 4 +--- torchvision/csrc/cuda/ps_roi_align_kernel.cu | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/torchvision/csrc/cpu/ps_roi_align_kernel.cpp b/torchvision/csrc/cpu/ps_roi_align_kernel.cpp index 2bfe6bfe132..a56fbe58e9a 100644 --- a/torchvision/csrc/cpu/ps_roi_align_kernel.cpp +++ b/torchvision/csrc/cpu/ps_roi_align_kernel.cpp @@ -1,6 +1,4 @@ -#include -#include -#include +#include "ps_roi_align_kernel.h" namespace { diff --git a/torchvision/csrc/cuda/ps_roi_align_kernel.cu b/torchvision/csrc/cuda/ps_roi_align_kernel.cu index 64a15deacaa..4ac0c28de4c 100644 --- a/torchvision/csrc/cuda/ps_roi_align_kernel.cu +++ b/torchvision/csrc/cuda/ps_roi_align_kernel.cu @@ -1,11 +1,9 @@ -#include -#include #include #include #include -#include #include "cuda_helpers.h" +#include "ps_roi_align_kernel.h" namespace {