Port roi_align to actually use dispatcher #2366
@@ -9,8 +9,9 @@
 #include "hip/vision_cuda.h"
 #endif

-// Interface for Python
-at::Tensor ROIAlign_forward(
+// TODO: put this stuff in torchvision namespace
+
+at::Tensor roi_align(
     const at::Tensor& input, // Input feature map.
     const at::Tensor& rois, // List of ROIs to pool over.
     const double spatial_scale, // The scale of the image features. ROIs will be
@@ -21,21 +22,10 @@ at::Tensor ROIAlign_forward(
     const bool aligned) // The flag for pixel shift
                         // along each axis.
 {
-  if (input.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return ROIAlign_forward_cuda(
-        input,
-        rois,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        sampling_ratio,
-        aligned);
-#else
-    AT_ERROR("Not compiled with GPU support");
-#endif
-  }
-  return ROIAlign_forward_cpu(
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::roi_align", "")
+                       .typed<decltype(roi_align)>();
+  return op.call(
       input,
       rois,
       spatial_scale,
@@ -45,37 +35,23 @@ at::Tensor ROIAlign_forward(
       aligned);
 }

-at::Tensor ROIAlign_backward(
+at::Tensor _roi_align_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width,
-    const int sampling_ratio,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t sampling_ratio,
     const bool aligned) {
-  if (grad.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return ROIAlign_backward_cuda(
-        grad,
-        rois,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        batch_size,
-        channels,
-        height,
-        width,
-        sampling_ratio,
-        aligned);
-#else
-    AT_ERROR("Not compiled with GPU support");
-#endif
-  }
-  return ROIAlign_backward_cpu(
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
+          .typed<decltype(_roi_align_backward)>();
+  return op.call(
       grad,
       rois,
       spatial_scale,
@@ -107,7 +83,8 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
     ctx->saved_data["aligned"] = aligned;
     ctx->saved_data["input_shape"] = input.sizes();
     ctx->save_for_backward({rois});
-    auto result = ROIAlign_forward(
+    at::AutoNonVariableTypeMode g;
+    auto result = roi_align(
         input,
         rois,
         spatial_scale,
@@ -125,7 +102,7 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
     auto saved = ctx->get_saved_variables();
     auto rois = saved[0];
     auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = ROIAlign_backward(
+    auto grad_in = _roi_align_backward(
         grad_output[0],
         rois,
         ctx->saved_data["spatial_scale"].toDouble(),
@@ -147,7 +124,47 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
   }
 };

-at::Tensor roi_align(
+// TODO: There should be an easier way to do this

Review comment (on the TODO line): Can't we register a fallback kernel that raises an error on double-backward? So that we don't have to implement a dummy double-backwards kernel for all the ops.

Reply: This would indeed be the right thing to do in the core library. There will be some BC consequences, though.
+class ROIAlignBackwardFunction
+    : public torch::autograd::Function<ROIAlignBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::Variable grad,
+      torch::autograd::Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width,
+      const int64_t batch_size,
+      const int64_t channels,
+      const int64_t height,
+      const int64_t width,
+      const int64_t sampling_ratio,
+      const bool aligned) {
+    at::AutoNonVariableTypeMode g;
+    auto result = _roi_align_backward(
+        grad,
+        rois,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        batch_size,
+        channels,
+        height,
+        width,
+        sampling_ratio,
+        aligned);
+    return {result};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_output) {
+    TORCH_CHECK(0, "double backwards on roi_align not supported");
+  }
+};
+
+at::Tensor ROIAlign_autograd(
     const at::Tensor& input,
     const at::Tensor& rois,
     const double spatial_scale,
@@ -164,3 +181,29 @@ at::Tensor roi_align(
       sampling_ratio,
       aligned)[0];
 }
+
+at::Tensor ROIAlign_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t sampling_ratio,
+    const bool aligned) {
+  return ROIAlignBackwardFunction::apply(
+      grad,
+      rois,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width,
+      sampling_ratio,
+      aligned)[0];
+}
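
ROIAlign_autograd and ROIAlign_backward_autograd only take effect once they are registered under the autograd dispatch key, which again happens outside this file. A sketch of what that registration might look like, assuming TORCH_LIBRARY_IMPL:

// Sketch only (assumption): routing differentiable calls through the autograd
// wrappers defined above.
#include <torch/library.h>

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  m.impl("roi_align", ROIAlign_autograd);
  m.impl("_roi_align_backward", ROIAlign_backward_autograd);
}

Once the schema, backend kernels, and autograd kernels are registered, the op is also callable from Python via torch.ops.torchvision.roi_align.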
Review comment (on the at::AutoNonVariableTypeMode guard): Why an explicit autograd-disabling guard here? Does torch::autograd::Function not disable autograd automatically around forward? The pre-PR forward doesn't use an explicit guard, and afaik the Python-side torch.autograd.Function does disable autograd around its forward method. Both of these lead me to expect torch::autograd::Function to also disable autograd around forward. (If it doesn't, I think it should; aligning its behavior with the Python version makes sense to me. But that would be a PyTorch-side change.)

Reply: pytorch/pytorch#40736
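
For context on the guard under discussion: if torch::autograd::Function does not itself exclude the autograd dispatch key around forward (the question the linked issue addresses), then calling the dispatcher-backed roi_align inside ROIAlignFunction::forward would select the autograd kernel again and recurse; the at::AutoNonVariableTypeMode guard lets the call fall through to the backend kernel. A self-contained toy sketch of that pattern (illustrative only, not torchvision code):

// Toy example of the guard pattern. square() semantics stand in for the
// dispatcher-registered op; SquareFunction stands in for ROIAlignFunction.
#include <torch/torch.h>

struct SquareFunction : public torch::autograd::Function<SquareFunction> {
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& x) {
    ctx->save_for_backward({x});
    // Assumption under discussion: forward() itself does not disable autograd,
    // so an explicit guard is needed to reach the non-autograd kernel directly.
    at::AutoNonVariableTypeMode guard;
    return x * x;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    auto x = ctx->get_saved_variables()[0];
    return {grad_output[0] * 2 * x};
  }
};

// Usage: auto y = SquareFunction::apply(torch::randn({3}, torch::requires_grad()));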