From 4a4cec0496f843cb6301121d874509107f0fe609 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 11:26:04 +0100 Subject: [PATCH 01/10] Adding operator. --- torchvision/ops/__init__.py | 3 +- torchvision/ops/stochastic_depth.py | 56 +++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 torchvision/ops/stochastic_depth.py diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 0ec189dbc2a..606c27abcbe 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -8,6 +8,7 @@ from .poolers import MultiScaleRoIAlign from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss +from .stochastic_depth import stochastic_depth, StochasticDepth from ._register_onnx_ops import _register_custom_op @@ -20,5 +21,5 @@ 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork', - 'sigmoid_focal_loss' + 'sigmoid_focal_loss', 'stochastic_depth', 'StochasticDepth' ] diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py new file mode 100644 index 00000000000..de30abaa36b --- /dev/null +++ b/torchvision/ops/stochastic_depth.py @@ -0,0 +1,56 @@ +import torch +from torch import nn, Tensor + + +def stochastic_depth(input: Tensor, mode: str, p: float, training: bool = True) -> Tensor: + """ + Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth" + `_ used for randomly dropping residual + branches of residual architectures. + + Args: + input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one + being its batch i.e. a batch with ``N`` rows. + mode (str): ``"batch"`` or ``"row"``. + ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes + randomly selected rows from the batch. + p (float): probability of the input to be zeroed. + training: apply dropout if is ``True``. Default: ``True`` + + Returns: + Tensor[N, ...]: The randomly zeroed tensor. + """ + if p < 0.0 or p > 1.0: + raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p)) + if not training or p == 0.0: + return input + + survival_rate = 1.0 - p + if mode == "batch": + keep = torch.rand(size=(1, ), dtype=input.dtype, device=input.device) < survival_rate + elif mode == "row": + keep = torch.rand(size=(input.size(0),), dtype=input.dtype, device=input.device) < survival_rate + keep = keep[(None, ) * (input.ndim - 1)].T + else: + raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode)) + return input / survival_rate * keep + + +class StochasticDepth(nn.Module): + """ + See :func:`stochastic_depth`. + """ + def __init__(self, mode: str, p: float): + super().__init__() + self.mode = mode + self.p = p + + def forward(self, input: Tensor) -> Tensor: + return stochastic_depth(input, self.mode, self.p, self.training) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + '(' + tmpstr += 'mode=' + str(self.mode) + tmpstr += ', p=' + str(self.p) + tmpstr += ')' + return tmpstr From 4170c075917adb74a127837d887a96d086bc4048 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 12:15:14 +0100 Subject: [PATCH 02/10] Adding tests --- test/test_ops.py | 35 +++++++++++++++++++++++++++++ torchvision/ops/stochastic_depth.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 5c2fc882902..c64c8950ba1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,6 +2,7 @@ import math from abc import ABC, abstractmethod import pytest +import random import numpy as np @@ -13,6 +14,11 @@ from torchvision import ops from typing import Tuple +try: + from scipy import stats +except ImportError: + stats = None + class RoIOpTester(ABC): dtype = torch.float64 @@ -1000,5 +1006,34 @@ def gen_iou_check(box, expected, tolerance=1e-4): gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3) +class TestStochasticDepth: + @pytest.mark.skipif(stats is None, reason="scipy.stats not available") + @pytest.mark.parametrize('mode', ["batch", "row"]) + @pytest.mark.parametrize('p', [0.2, 0.5, 0.8]) + def test_stochastic_depth(self, mode, p): + random.seed(42) + batch_size = 5 + x = torch.ones(size=(batch_size, 3, 4, 4)) + layer = ops.StochasticDepth(mode=mode, p=p).to(device=x.device, dtype=x.dtype) + layer.__repr__() + + trials = 250 + num_samples = 0 + counts = 0 + for _ in range(trials): + out = layer(x) + non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0) + if mode == "batch": + if non_zero_count == 0: + counts += 1 + num_samples += 1 + elif mode == "row": + counts += batch_size - non_zero_count + num_samples += batch_size + + p_value = stats.binom_test(counts, num_samples, p=p) + assert p_value > 0.0001 + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py index de30abaa36b..a9b4338aa73 100644 --- a/torchvision/ops/stochastic_depth.py +++ b/torchvision/ops/stochastic_depth.py @@ -40,7 +40,7 @@ class StochasticDepth(nn.Module): """ See :func:`stochastic_depth`. """ - def __init__(self, mode: str, p: float): + def __init__(self, mode: str, p: float) -> None: super().__init__() self.mode = mode self.p = p From 9c8de0e00c69eed2de2e5fd51d18eafe8cf97af2 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 13:17:26 +0100 Subject: [PATCH 03/10] switching order of `p` and `mode`. --- test/test_ops.py | 4 ++-- torchvision/ops/stochastic_depth.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index c64c8950ba1..c0a9b518afe 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1008,13 +1008,13 @@ def gen_iou_check(box, expected, tolerance=1e-4): class TestStochasticDepth: @pytest.mark.skipif(stats is None, reason="scipy.stats not available") - @pytest.mark.parametrize('mode', ["batch", "row"]) @pytest.mark.parametrize('p', [0.2, 0.5, 0.8]) + @pytest.mark.parametrize('mode', ["batch", "row"]) def test_stochastic_depth(self, mode, p): random.seed(42) batch_size = 5 x = torch.ones(size=(batch_size, 3, 4, 4)) - layer = ops.StochasticDepth(mode=mode, p=p).to(device=x.device, dtype=x.dtype) + layer = ops.StochasticDepth(p=p, mode=mode).to(device=x.device, dtype=x.dtype) layer.__repr__() trials = 250 diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py index a9b4338aa73..2231064d7e0 100644 --- a/torchvision/ops/stochastic_depth.py +++ b/torchvision/ops/stochastic_depth.py @@ -2,7 +2,7 @@ from torch import nn, Tensor -def stochastic_depth(input: Tensor, mode: str, p: float, training: bool = True) -> Tensor: +def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) -> Tensor: """ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth" `_ used for randomly dropping residual @@ -11,10 +11,10 @@ def stochastic_depth(input: Tensor, mode: str, p: float, training: bool = True) Args: input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one being its batch i.e. a batch with ``N`` rows. + p (float): probability of the input to be zeroed. mode (str): ``"batch"`` or ``"row"``. ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes randomly selected rows from the batch. - p (float): probability of the input to be zeroed. training: apply dropout if is ``True``. Default: ``True`` Returns: @@ -40,17 +40,17 @@ class StochasticDepth(nn.Module): """ See :func:`stochastic_depth`. """ - def __init__(self, mode: str, p: float) -> None: + def __init__(self, p: float, mode: str) -> None: super().__init__() - self.mode = mode self.p = p + self.mode = mode def forward(self, input: Tensor) -> Tensor: - return stochastic_depth(input, self.mode, self.p, self.training) + return stochastic_depth(input, self.p, self.mode, self.training) def __repr__(self) -> str: tmpstr = self.__class__.__name__ + '(' - tmpstr += 'mode=' + str(self.mode) - tmpstr += ', p=' + str(self.p) + tmpstr += 'p=' + str(self.p) + tmpstr += ', mode=' + str(self.mode) tmpstr += ')' return tmpstr From bc58919e7cf972c411df21f42d2bcebecd96230f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 13:23:07 +0100 Subject: [PATCH 04/10] Remove seed setting. --- test/test_ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index c0a9b518afe..59bd08aae2f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2,7 +2,6 @@ import math from abc import ABC, abstractmethod import pytest -import random import numpy as np @@ -1011,7 +1010,6 @@ class TestStochasticDepth: @pytest.mark.parametrize('p', [0.2, 0.5, 0.8]) @pytest.mark.parametrize('mode', ["batch", "row"]) def test_stochastic_depth(self, mode, p): - random.seed(42) batch_size = 5 x = torch.ones(size=(batch_size, 3, 4, 4)) layer = ops.StochasticDepth(p=p, mode=mode).to(device=x.device, dtype=x.dtype) From 27e0b7b995582f77eaee05ac901e28120a89dfec Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 13:30:03 +0100 Subject: [PATCH 05/10] Replace stats import with pytest.importorskip. --- test/test_ops.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 59bd08aae2f..c64ba1fd0bb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -13,11 +13,6 @@ from torchvision import ops from typing import Tuple -try: - from scipy import stats -except ImportError: - stats = None - class RoIOpTester(ABC): dtype = torch.float64 @@ -1006,10 +1001,10 @@ def gen_iou_check(box, expected, tolerance=1e-4): class TestStochasticDepth: - @pytest.mark.skipif(stats is None, reason="scipy.stats not available") @pytest.mark.parametrize('p', [0.2, 0.5, 0.8]) @pytest.mark.parametrize('mode', ["batch", "row"]) def test_stochastic_depth(self, mode, p): + stats = pytest.importorskip("scipy.stats") batch_size = 5 x = torch.ones(size=(batch_size, 3, 4, 4)) layer = ops.StochasticDepth(p=p, mode=mode).to(device=x.device, dtype=x.dtype) From 4f839598e99e792be96194c69637444be0925b79 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 13:32:35 +0100 Subject: [PATCH 06/10] Fix doc --- torchvision/ops/stochastic_depth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py index 2231064d7e0..35e395f82b9 100644 --- a/torchvision/ops/stochastic_depth.py +++ b/torchvision/ops/stochastic_depth.py @@ -15,7 +15,7 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) mode (str): ``"batch"`` or ``"row"``. ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes randomly selected rows from the batch. - training: apply dropout if is ``True``. Default: ``True`` + training: apply stochastic depth if is ``True``. Default: ``True`` Returns: Tensor[N, ...]: The randomly zeroed tensor. From 2ca25a9bda54a1de6aad761ee5d7d529bc0e81a4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 13:35:09 +0100 Subject: [PATCH 07/10] Apply suggestions from code review Co-authored-by: Francisco Massa --- torchvision/ops/stochastic_depth.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py index 35e395f82b9..c3cff61a937 100644 --- a/torchvision/ops/stochastic_depth.py +++ b/torchvision/ops/stochastic_depth.py @@ -26,14 +26,14 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) return input survival_rate = 1.0 - p - if mode == "batch": - keep = torch.rand(size=(1, ), dtype=input.dtype, device=input.device) < survival_rate - elif mode == "row": - keep = torch.rand(size=(input.size(0),), dtype=input.dtype, device=input.device) < survival_rate - keep = keep[(None, ) * (input.ndim - 1)].T - else: + if mode not in ["batch", "row"]: raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode)) - return input / survival_rate * keep + size = [1] * input.ndim + if mode == "row": + size[0] = input.shape[0] + noise = torch.empty(size, dtype=input.dtype, device=input.device) + noise = noise.bernoulli_(survival_rate).div_(survival_rate) + return input * noise class StochasticDepth(nn.Module): From be0bf043092e5ceea03947cdcb58b0ce9e8883a4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 13:38:17 +0100 Subject: [PATCH 08/10] Fixing indentation. --- torchvision/ops/stochastic_depth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py index c3cff61a937..e34025d7d9f 100644 --- a/torchvision/ops/stochastic_depth.py +++ b/torchvision/ops/stochastic_depth.py @@ -31,9 +31,9 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) size = [1] * input.ndim if mode == "row": size[0] = input.shape[0] - noise = torch.empty(size, dtype=input.dtype, device=input.device) - noise = noise.bernoulli_(survival_rate).div_(survival_rate) - return input * noise + noise = torch.empty(size, dtype=input.dtype, device=input.device) + noise = noise.bernoulli_(survival_rate).div_(survival_rate) + return input * noise class StochasticDepth(nn.Module): From d050bb91a39fd84399d2f44b10a29118fc7fb110 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 13:40:34 +0100 Subject: [PATCH 09/10] Adding operator in the documentation. --- docs/source/ops.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index cdebe9721c3..ecef74dd8a6 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -23,6 +23,7 @@ torchvision.ops .. autofunction:: ps_roi_pool .. autofunction:: deform_conv2d .. autofunction:: sigmoid_focal_loss +.. autofunction:: stochastic_depth .. autoclass:: RoIAlign .. autoclass:: PSRoIAlign @@ -31,3 +32,4 @@ torchvision.ops .. autoclass:: DeformConv2d .. autoclass:: MultiScaleRoIAlign .. autoclass:: FeaturePyramidNetwork +.. autoclass:: StochasticDepth From ae9be89c34bb1a77986a68017dc3a1fbc1d6f184 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 13:43:34 +0100 Subject: [PATCH 10/10] Fixing lint --- torchvision/ops/stochastic_depth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py index e34025d7d9f..f3338242a76 100644 --- a/torchvision/ops/stochastic_depth.py +++ b/torchvision/ops/stochastic_depth.py @@ -30,7 +30,7 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode)) size = [1] * input.ndim if mode == "row": - size[0] = input.shape[0] + size[0] = input.shape[0] noise = torch.empty(size, dtype=input.dtype, device=input.device) noise = noise.bernoulli_(survival_rate).div_(survival_rate) return input * noise