Skip to content

Commit 5cb77a2

Browse files
authored
Static Analysis corrections on DeformConv (#2885)
* Convert to const reference and eliminate unnecessary bool casting. * Removing unnecessary namespace use.
1 parent cffac64 commit 5cb77a2

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

torchvision/csrc/cpu/DeformConv_cpu.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@
7474
#include <iostream>
7575
#include <tuple>
7676

77-
using namespace at;
78-
7977
const int kMaxParallelImgs = 32;
8078

8179
template <typename scalar_t>
@@ -597,7 +595,7 @@ static void deformable_col2im_coord_kernel(
597595
out_w;
598596

599597
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
600-
const int is_y_direction = offset_c % 2 == 0;
598+
const bool is_y_direction = offset_c % 2 == 0;
601599

602600
const int c_bound = c_per_offset_grp * weight_h * weight_w;
603601
for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
@@ -812,9 +810,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
812810

813811
static at::Tensor deform_conv2d_backward_parameters_cpu(
814812
at::Tensor input,
815-
at::Tensor weight,
813+
const at::Tensor& weight,
816814
at::Tensor offset,
817-
at::Tensor grad_out,
815+
const at::Tensor& grad_out,
818816
std::pair<int, int> stride,
819817
std::pair<int, int> pad,
820818
std::pair<int, int> dilation,

torchvision/csrc/cuda/DeformConv_cuda.cu

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@
7878
#include <iostream>
7979
#include <tuple>
8080

81-
using namespace at;
82-
8381
const unsigned int CUDA_NUM_THREADS = 1024;
8482
const int kMaxParallelImgs = 32;
8583

@@ -618,7 +616,7 @@ __global__ void deformable_col2im_coord_gpu_kernel(
618616
out_h * out_w;
619617
620618
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
621-
const int is_y_direction = offset_c % 2 == 0;
619+
const bool is_y_direction = offset_c % 2 == 0;
622620
623621
const int c_bound = c_per_offset_grp * weight_h * weight_w;
624622
for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
@@ -840,9 +838,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
840838
841839
static at::Tensor deform_conv_backward_parameters_cuda(
842840
at::Tensor input,
843-
at::Tensor weight,
841+
const at::Tensor& weight,
844842
at::Tensor offset,
845-
at::Tensor grad_out,
843+
const at::Tensor& grad_out,
846844
std::pair<int, int> stride,
847845
std::pair<int, int> pad,
848846
std::pair<int, int> dilation,

0 commit comments

Comments
 (0)