Encapsulate and standardize ps_roi_align #3082

Merged
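For context: this PR renames the C++ entry points from the PSROIAlign_* camel-case style to snake_case ps_roi_align_*, hides the templated element-wise kernels inside anonymous namespaces, and moves the public per-backend declarations into new ps_roi_align_kernel.h headers. A purely illustrative call site for the renamed CPU forward is sketched below; the shapes, scale, and sampling ratio are placeholder values, not taken from this PR.

```cpp
// Hypothetical caller, assuming the new header added by this PR is on the
// include path. The channel count is chosen so that it equals
// out_channels * pooled_height * pooled_width (1 * 7 * 7 = 49).
#include <tuple>
#include <ATen/ATen.h>
#include "cpu/ps_roi_align_kernel.h"

std::tuple<at::Tensor, at::Tensor> example() {
  at::Tensor input = at::rand({1, 49, 32, 32});                 // N x (C*PH*PW) x H x W
  at::Tensor rois =
      at::tensor({0.f, 0.f, 0.f, 16.f, 16.f}).reshape({1, 5});  // (batch_idx, x1, y1, x2, y2)

  // Old name: PSROIAlign_forward_cpu(...). New name after this PR:
  return ps_roi_align_forward_cpu(
      input,
      rois,
      /*spatial_scale=*/1.0,
      /*pooled_height=*/7,
      /*pooled_width=*/7,
      /*sampling_ratio=*/2);
}
```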
@@ -1,6 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
#include "ps_roi_align_kernel.h"

namespace {

template <typename T>
T bilinear_interpolate(
@@ -57,7 +57,7 @@ T bilinear_interpolate(
}

template <typename T>
void PSROIAlignForwardCPU(
void ps_roi_align_forward_kernel_impl(
int nthreads,
const T* input,
const T spatial_scale,
@@ -202,7 +202,7 @@ inline void add(T* address, const T& val) {
}

template <typename T>
void PSROIAlignBackwardCPU(
void ps_roi_align_backward_kernel_impl(
int nthreads,
const T* grad_output,
const int* channel_mapping,
@@ -298,7 +298,9 @@ void PSROIAlignBackwardCPU(
}
}

std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
} // namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
@@ -313,7 +315,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(

at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

at::CheckedFrom c = "PSROIAlign_forward_cpu";
at::CheckedFrom c = "ps_roi_align_forward_cpu";
at::checkAllSameType(c, {input_t, rois_t});

int num_rois = rois.size(0);
@@ -338,8 +340,8 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(

auto input_ = input.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "PSROIAlign_forward", [&] {
PSROIAlignForwardCPU<scalar_t>(
input.scalar_type(), "ps_roi_align_forward", [&] {
ps_roi_align_forward_kernel_impl<scalar_t>(
output_size,
input_.data_ptr<scalar_t>(),
spatial_scale,
@@ -357,7 +359,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
return std::make_tuple(output, channel_mapping);
}

at::Tensor PSROIAlign_backward_cpu(
at::Tensor ps_roi_align_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
@@ -379,7 +381,7 @@ at::Tensor PSROIAlign_backward_cpu(
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};

at::CheckedFrom c = "PSROIAlign_backward_cpu";
at::CheckedFrom c = "ps_roi_align_backward_cpu";
at::checkAllSameType(c, {grad_t, rois_t});

auto num_rois = rois.size(0);
@@ -395,8 +397,8 @@ at::Tensor PSROIAlign_backward_cpu(

auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "PSROIAlign_backward", [&] {
PSROIAlignBackwardCPU<scalar_t>(
grad.scalar_type(), "ps_roi_align_backward", [&] {
ps_roi_align_backward_kernel_impl<scalar_t>(
grad.numel(),
grad_.data_ptr<scalar_t>(),
channel_mapping.data_ptr<int>(),
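The `} // namespace` added above is the encapsulation half of the change: the templated ps_roi_align_*_kernel_impl helpers now have internal linkage, so only the snake_case ps_roi_align_*_cpu entry points remain visible to other translation units. A minimal, self-contained illustration of that pattern (all names below are invented for the example):

```cpp
#include <cstdio>

namespace {
// Internal linkage: cannot be referenced from other .cpp files, which is
// exactly what happens to the *_kernel_impl templates in this diff.
int helper_kernel_impl(int x) {
  return 2 * x;
}
} // namespace

// External linkage: the only symbol this file exposes, mirroring
// ps_roi_align_forward_cpu / ps_roi_align_backward_cpu.
int public_entry_point(int x) {
  return helper_kernel_impl(x);
}

int main() {
  std::printf("%d\n", public_entry_point(21)); // prints 42
}
```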
25 changes: 25 additions & 0 deletions torchvision/csrc/cpu/ps_roi_align_kernel.h
@@ -0,0 +1,25 @@
#pragma once

#include <ATen/ATen.h>
#include "../macros.h"

VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio);

VISION_API at::Tensor ps_roi_align_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
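The declarations above carry VISION_API, pulled in from "../macros.h". That macro's definition is not part of this diff; export macros of this kind usually expand to a dllexport/dllimport pair on Windows and a visibility attribute elsewhere. The sketch below only illustrates the idea and does not claim to match torchvision's actual macros.h:

```cpp
// Rough sketch of a cross-platform symbol-export macro. The guard name
// (torchvision_EXPORTS) is an assumption, not taken from this repository.
#if defined(_WIN32)
#if defined(torchvision_EXPORTS)
#define VISION_API __declspec(dllexport) // building the shared library
#else
#define VISION_API __declspec(dllimport) // consuming the shared library
#endif
#else
#define VISION_API __attribute__((visibility("default")))
#endif
```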
21 changes: 0 additions & 21 deletions torchvision/csrc/cpu/vision_cpu.h
@@ -4,27 +4,6 @@

// TODO: Delete this file once all the methods are gone

VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio);

VISION_API at::Tensor PSROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);

VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
@@ -1,11 +1,11 @@
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>

#include "cuda_helpers.h"
#include "ps_roi_align_kernel.h"

namespace {

template <typename T>
__device__ T bilinear_interpolate(
@@ -62,7 +62,7 @@ __device__ T bilinear_interpolate(
}

template <typename T>
__global__ void PSROIAlignForwardCUDA(
__global__ void ps_roi_align_forward_kernel_impl(
int nthreads,
const T* input,
const T spatial_scale,
@@ -195,7 +195,7 @@ __device__ void bilinear_interpolate_gradient(
}

template <typename T>
__global__ void PSROIAlignBackwardCUDA(
__global__ void ps_roi_align_backward_kernel_impl(
int nthreads,
const T* grad_output,
const int* channel_mapping,
@@ -292,7 +292,9 @@ __global__ void PSROIAlignBackwardCUDA(
}
}

std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
} // namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
@@ -307,7 +309,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(

at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

at::CheckedFrom c = "PSROIAlign_forward_cuda";
at::CheckedFrom c = "ps_roi_align_forward_cuda";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});

@@ -344,8 +346,8 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
auto input_ = input.contiguous(),
rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "PSROIAlign_forward", [&] {
PSROIAlignForwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
input.scalar_type(), "ps_roi_align_forward", [&] {
ps_roi_align_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
output_size,
input_.data_ptr<scalar_t>(),
spatial_scale,
@@ -365,7 +367,7 @@ std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
return std::make_tuple(output, channel_mapping);
}

at::Tensor PSROIAlign_backward_cuda(
at::Tensor ps_roi_align_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
@@ -387,7 +389,7 @@ at::Tensor PSROIAlign_backward_cuda(
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};

at::CheckedFrom c = "PSROIAlign_backward_cuda";
at::CheckedFrom c = "ps_roi_align_backward_cuda";
at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
at::checkAllSameType(c, {grad_t, rois_t});

@@ -415,8 +417,8 @@ at::Tensor PSROIAlign_backward_cuda(
auto grad_ = grad.contiguous(),
rois_ = rois.contiguous();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "PSROIAlign_backward", [&] {
PSROIAlignBackwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
grad.scalar_type(), "ps_roi_align_backward", [&] {
ps_roi_align_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad_.data_ptr<scalar_t>(),
channel_mapping.data_ptr<int>(),
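The grid, block, and stream arguments of the `<<<grid, block, 0, stream>>>` launches above are computed earlier in the function and elided by the diff. A typical configuration for this kind of element-wise ROI kernel, written here as an assumption rather than the PR's exact code, is one thread per output element with a capped grid size on the current ATen CUDA stream:

```cpp
// Assumed launch setup; the real values in the .cu file above may differ.
#include <algorithm>
#include <ATen/cuda/CUDAContext.h>

void launch_config_sketch(int64_t output_size) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 block(512);
  dim3 grid(static_cast<unsigned int>(std::min<int64_t>(
      (output_size + block.x - 1) / block.x, // ceil-div: one thread per output element
      4096)));                               // cap the grid size
  (void)stream; // the real code forwards grid, block, and stream to the launch
  (void)grid;
}
```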
25 changes: 25 additions & 0 deletions torchvision/csrc/cuda/ps_roi_align_kernel.h
@@ -0,0 +1,25 @@
#pragma once

#include <ATen/ATen.h>
#include "../macros.h"

VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio);

VISION_API at::Tensor ps_roi_align_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
21 changes: 0 additions & 21 deletions torchvision/csrc/cuda/vision_cuda.h
@@ -4,27 +4,6 @@

// TODO: Delete this file once all the methods are gone

VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio);

VISION_API at::Tensor PSROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);

VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
28 changes: 11 additions & 17 deletions torchvision/csrc/PSROIAlign.h → torchvision/csrc/ps_roi_align.cpp
@@ -1,20 +1,10 @@
#pragma once
#include "ps_roi_align.h"
#include <torch/extension.h>

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
#endif

#include <iostream>

// TODO: put this stuff in torchvision namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_align(
const at::Tensor& input,
const at::Tensor& rois,
@@ -30,7 +20,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align(
}

#if defined(WITH_CUDA) || defined(WITH_HIP)
std::tuple<at::Tensor, at::Tensor> PSROIAlign_autocast(
std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
@@ -82,6 +72,8 @@ at::Tensor _ps_roi_align_backward(
width);
}

namespace {

class PSROIAlignFunction
: public torch::autograd::Function<PSROIAlignFunction> {
public:
@@ -186,7 +178,9 @@ class PSROIAlignBackwardFunction
}
};

std::tuple<at::Tensor, at::Tensor> PSROIAlign_autograd(
} // namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
@@ -199,7 +193,7 @@ at::Tensor PSROIAlign_autograd(
return std::make_tuple(result[0], result[1]);
}

at::Tensor PSROIAlign_backward_autograd(
at::Tensor ps_roi_align_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
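The body of ps_roi_align_autocast is elided in the hunk above. Wrappers of this shape conventionally exclude the Autocast dispatch key, cast the floating-point inputs to fp32, and cast the pooled output back to the caller's dtype. The sketch below describes that convention as an assumption; it is not the PR's exact code.

```cpp
#include <tuple>
#include <ATen/autocast_mode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include "ps_roi_align.h" // declares ps_roi_align, as included at the top of this file

std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast_sketch(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
  // Re-dispatch with autocast excluded so the call below does not recurse
  // back into this wrapper.
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  auto result = ps_roi_align(
      at::autocast::cached_cast(at::kFloat, input),
      at::autocast::cached_cast(at::kFloat, rois),
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio);
  // Pooled output goes back to the caller's dtype; channel_mapping is an
  // integer tensor and is returned unchanged.
  return std::make_tuple(
      std::get<0>(result).to(input.scalar_type()),
      std::get<1>(result));
}
```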