Skip to content

Commit 3c33f36

Browse files
authored
Per file C++ Operator registration (#3135)
* Moving deform_conv2d op registration. * Moving nms op registration. * Moving new_empty_tensor op registration. * Moving ps_roi_align op registration. * Moving ps_roi_pool op registration. * Moving roi_align op registration. * Moving roi_pool op registration. * Restoring headers for forward/backward and fixing styles. * Restoring the test hack on windows. * Stricter header inclusion.
1 parent 6cb4fc2 commit 3c33f36

40 files changed

+306
-804
lines changed

test/tracing/frcnn/test_frcnn_tracing.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#include <ATen/ATen.h>
22
#include <torch/script.h>
33
#include <torch/torch.h>
4-
#include <torchvision/roi_align.h>
54
#include <torchvision/nms.h>
5+
#include <torchvision/roi_align.h>
66

77
#ifdef _WIN32
88
// Windows only
99
// This is necessary until operators are automatically registered on include
10-
static auto _nms = &vision::ops::nms_cpu;
10+
static auto _nms = &vision::ops::nms;
1111
#endif
1212

1313
int main() {

torchvision/csrc/cpu/deform_conv2d_kernel.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@
6666
// modified from
6767
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
6868

69-
#include "deform_conv2d_kernel.h"
69+
#include <ATen/ATen.h>
70+
#include <torch/library.h>
7071

7172
namespace vision {
7273
namespace ops {
@@ -852,9 +853,7 @@ at::Tensor backward_gradient_parameters(
852853
return grad_weight;
853854
}
854855

855-
} // namespace
856-
857-
at::Tensor deform_conv2d_forward_cpu(
856+
at::Tensor deform_conv2d_forward_kernel(
858857
const at::Tensor& input,
859858
const at::Tensor& weight,
860859
const at::Tensor& offset,
@@ -1070,7 +1069,7 @@ at::Tensor deform_conv2d_forward_cpu(
10701069
}
10711070

10721071
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
1073-
deform_conv2d_backward_cpu(
1072+
deform_conv2d_backward_kernel(
10741073
const at::Tensor& grad_out,
10751074
const at::Tensor& input,
10761075
const at::Tensor& weight,
@@ -1141,5 +1140,12 @@ deform_conv2d_backward_cpu(
11411140
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
11421141
}
11431142

1143+
} // namespace
1144+
1145+
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
1146+
m.impl("deform_conv2d", deform_conv2d_forward_kernel);
1147+
m.impl("_deform_conv2d_backward", deform_conv2d_backward_kernel);
1148+
}
1149+
11441150
} // namespace ops
11451151
} // namespace vision

torchvision/csrc/cpu/deform_conv2d_kernel.h

Lines changed: 0 additions & 45 deletions
This file was deleted.

torchvision/csrc/cpu/nms_kernel.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#include "nms_kernel.h"
1+
#include <ATen/ATen.h>
2+
#include <torch/library.h>
23

34
namespace vision {
45
namespace ops {
@@ -74,9 +75,7 @@ at::Tensor nms_kernel_impl(
7475
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
7576
}
7677

77-
} // namespace
78-
79-
at::Tensor nms_cpu(
78+
at::Tensor nms_kernel(
8079
const at::Tensor& dets,
8180
const at::Tensor& scores,
8281
double iou_threshold) {
@@ -101,11 +100,17 @@ at::Tensor nms_cpu(
101100

102101
auto result = at::empty({0}, dets.options());
103102

104-
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_cpu", [&] {
103+
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] {
105104
result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
106105
});
107106
return result;
108107
}
109108

109+
} // namespace
110+
111+
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
112+
m.impl("nms", nms_kernel);
113+
}
114+
110115
} // namespace ops
111116
} // namespace vision

torchvision/csrc/cpu/nms_kernel.h

Lines changed: 0 additions & 15 deletions
This file was deleted.

torchvision/csrc/cpu/ps_roi_align_kernel.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#include "ps_roi_align_kernel.h"
1+
#include <ATen/ATen.h>
2+
#include <torch/library.h>
23

34
namespace vision {
45
namespace ops {
@@ -301,9 +302,7 @@ void ps_roi_align_backward_kernel_impl(
301302
}
302303
}
303304

304-
} // namespace
305-
306-
std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
305+
std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
307306
const at::Tensor& input,
308307
const at::Tensor& rois,
309308
double spatial_scale,
@@ -318,7 +317,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
318317

319318
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
320319

321-
at::CheckedFrom c = "ps_roi_align_forward_cpu";
320+
at::CheckedFrom c = "ps_roi_align_forward_kernel";
322321
at::checkAllSameType(c, {input_t, rois_t});
323322

324323
int num_rois = rois.size(0);
@@ -343,7 +342,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
343342

344343
auto input_ = input.contiguous(), rois_ = rois.contiguous();
345344
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
346-
input.scalar_type(), "ps_roi_align_forward_cpu", [&] {
345+
input.scalar_type(), "ps_roi_align_forward_kernel", [&] {
347346
ps_roi_align_forward_kernel_impl<scalar_t>(
348347
output_size,
349348
input_.data_ptr<scalar_t>(),
@@ -362,7 +361,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
362361
return std::make_tuple(output, channel_mapping);
363362
}
364363

365-
at::Tensor ps_roi_align_backward_cpu(
364+
at::Tensor ps_roi_align_backward_kernel(
366365
const at::Tensor& grad,
367366
const at::Tensor& rois,
368367
const at::Tensor& channel_mapping,
@@ -384,7 +383,7 @@ at::Tensor ps_roi_align_backward_cpu(
384383
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
385384
channel_mapping_t{channel_mapping, "channel_mapping", 3};
386385

387-
at::CheckedFrom c = "ps_roi_align_backward_cpu";
386+
at::CheckedFrom c = "ps_roi_align_backward_kernel";
388387
at::checkAllSameType(c, {grad_t, rois_t});
389388

390389
auto num_rois = rois.size(0);
@@ -400,7 +399,7 @@ at::Tensor ps_roi_align_backward_cpu(
400399

401400
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
402401
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
403-
grad.scalar_type(), "ps_roi_align_backward_cpu", [&] {
402+
grad.scalar_type(), "ps_roi_align_backward_kernel", [&] {
404403
ps_roi_align_backward_kernel_impl<scalar_t>(
405404
grad.numel(),
406405
grad_.data_ptr<scalar_t>(),
@@ -420,5 +419,12 @@ at::Tensor ps_roi_align_backward_cpu(
420419
return grad_input;
421420
}
422421

422+
} // namespace
423+
424+
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
425+
m.impl("ps_roi_align", ps_roi_align_forward_kernel);
426+
m.impl("_ps_roi_align_backward", ps_roi_align_backward_kernel);
427+
}
428+
423429
} // namespace ops
424430
} // namespace vision

torchvision/csrc/cpu/ps_roi_align_kernel.h

Lines changed: 0 additions & 31 deletions
This file was deleted.

torchvision/csrc/cpu/ps_roi_pool_kernel.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#include "ps_roi_pool_kernel.h"
1+
#include <ATen/ATen.h>
2+
#include <torch/library.h>
23

34
namespace vision {
45
namespace ops {
@@ -145,9 +146,7 @@ void ps_roi_pool_backward_kernel_impl(
145146
}
146147
}
147148

148-
} // namespace
149-
150-
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(
149+
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_kernel(
151150
const at::Tensor& input,
152151
const at::Tensor& rois,
153152
double spatial_scale,
@@ -161,7 +160,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(
161160

162161
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
163162

164-
at::CheckedFrom c = "ps_roi_pool_forward_cpu";
163+
at::CheckedFrom c = "ps_roi_pool_forward_kernel";
165164
at::checkAllSameType(c, {input_t, rois_t});
166165

167166
int num_rois = rois.size(0);
@@ -186,7 +185,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(
186185

187186
auto input_ = input.contiguous(), rois_ = rois.contiguous();
188187
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
189-
input.scalar_type(), "ps_roi_pool_forward_cpu", [&] {
188+
input.scalar_type(), "ps_roi_pool_forward_kernel", [&] {
190189
ps_roi_pool_forward_kernel_impl<scalar_t>(
191190
input_.data_ptr<scalar_t>(),
192191
spatial_scale,
@@ -204,7 +203,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(
204203
return std::make_tuple(output, channel_mapping);
205204
}
206205

207-
at::Tensor ps_roi_pool_backward_cpu(
206+
at::Tensor ps_roi_pool_backward_kernel(
208207
const at::Tensor& grad,
209208
const at::Tensor& rois,
210209
const at::Tensor& channel_mapping,
@@ -225,7 +224,7 @@ at::Tensor ps_roi_pool_backward_cpu(
225224
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
226225
channel_mapping_t{channel_mapping, "channel_mapping", 3};
227226

228-
at::CheckedFrom c = "ps_roi_pool_backward_cpu";
227+
at::CheckedFrom c = "ps_roi_pool_backward_kernel";
229228
at::checkAllSameType(c, {grad_t, rois_t});
230229

231230
auto num_rois = rois.size(0);
@@ -241,7 +240,7 @@ at::Tensor ps_roi_pool_backward_cpu(
241240

242241
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
243242
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
244-
grad.scalar_type(), "ps_roi_pool_backward_cpu", [&] {
243+
grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] {
245244
ps_roi_pool_backward_kernel_impl<scalar_t>(
246245
grad_.data_ptr<scalar_t>(),
247246
channel_mapping.data_ptr<int>(),
@@ -259,5 +258,12 @@ at::Tensor ps_roi_pool_backward_cpu(
259258
return grad_input;
260259
}
261260

261+
} // namespace
262+
263+
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
264+
m.impl("ps_roi_pool", ps_roi_pool_forward_kernel);
265+
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_kernel);
266+
}
267+
262268
} // namespace ops
263269
} // namespace vision

torchvision/csrc/cpu/ps_roi_pool_kernel.h

Lines changed: 0 additions & 29 deletions
This file was deleted.

0 commit comments

Comments
 (0)