Skip to content

Commit 02a1918

Browse files
NicolasHug and fmassa authored
Minor cleanup of roi_align_forward_kernel_impl (#3619)
* minor clean up
* do same for ps_roialign

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
1 parent 591c899 commit 02a1918

File tree

2 files changed: 5 additions and 10 deletions.

torchvision/csrc/ops/cpu/ps_roi_align_kernel.cpp

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ T bilinear_interpolate(
6262

6363
template <typename T>
6464
void ps_roi_align_forward_kernel_impl(
65-
int nthreads,
65+
int num_rois,
6666
const T* input,
6767
const T spatial_scale,
6868
int channels,
@@ -75,7 +75,6 @@ void ps_roi_align_forward_kernel_impl(
7575
int channels_out,
7676
T* output,
7777
int* channel_mapping) {
78-
int num_rois = nthreads / channels_out / pooled_width / pooled_height;
7978
for (int n = 0; n < num_rois; n++) {
8079
// [start, end) interval for spatial sampling
8180
const T* offset_rois = rois + n * 5;
@@ -335,16 +334,15 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
335334
auto channel_mapping =
336335
at::zeros(output.sizes(), input.options().dtype(at::kInt));
337336

338-
auto output_size = output.numel();
339-
if (output_size == 0) {
337+
if (output.numel() == 0) {
340338
return std::make_tuple(output, channel_mapping);
341339
}
342340

343341
auto input_ = input.contiguous(), rois_ = rois.contiguous();
344342
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
345343
input.scalar_type(), "ps_roi_align_forward_kernel", [&] {
346344
ps_roi_align_forward_kernel_impl<scalar_t>(
347-
output_size,
345+
num_rois,
348346
input_.data_ptr<scalar_t>(),
349347
spatial_scale,
350348
channels,

torchvision/csrc/ops/cpu/roi_align_kernel.cpp

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,7 @@ void pre_calc_for_bilinear_interpolate(
117117

118118
template <typename T>
119119
void roi_align_forward_kernel_impl(
120-
int nthreads,
120+
int n_rois,
121121
const T* input,
122122
const T& spatial_scale,
123123
int channels,
@@ -129,7 +129,6 @@ void roi_align_forward_kernel_impl(
129129
bool aligned,
130130
const T* rois,
131131
T* output) {
132-
int n_rois = nthreads / channels / pooled_width / pooled_height;
133132
// (n, c, ph, pw) is an element in the pooled output
134133
// can be parallelized using omp
135134
// #pragma omp parallel for num_threads(32)
@@ -414,16 +413,14 @@ at::Tensor roi_align_forward_kernel(
414413
at::Tensor output = at::zeros(
415414
{num_rois, channels, pooled_height, pooled_width}, input.options());
416415

417-
auto output_size = num_rois * pooled_height * pooled_width * channels;
418-
419416
if (output.numel() == 0)
420417
return output;
421418

422419
auto input_ = input.contiguous(), rois_ = rois.contiguous();
423420
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
424421
input.scalar_type(), "roi_align_forward_kernel", [&] {
425422
roi_align_forward_kernel_impl<scalar_t>(
426-
output_size,
423+
num_rois,
427424
input_.data_ptr<scalar_t>(),
428425
spatial_scale,
429426
channels,

0 commit comments

Comments
 (0)