
Avoid calling torch.rand in Transformation.forward() #3066

Closed
@datumbox

Description


The preferred way to structure the Transformation classes is to put the sampling of random weights/params in a static get_params() method. The method should receive every hyperparameter needed for the sampling and return all the random variables the transformation needs. forward() then calls this method during the transformation. Here is an example of how this looks:

```python
@staticmethod
def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
    ...

# Inside forward():
i, j, h, w = self.get_params(img, self.size)
```
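A minimal, self-contained sketch of the full pattern, modeled on the signature quoted above. `RandomCropSketch` is a hypothetical class, not torchvision's `RandomCrop`. It assumes a tensor image laid out as `[..., H, W]` that is at least as large as the requested crop, and it deliberately skips padding and PIL support.

```python
from typing import Tuple

import torch
from torch import Tensor


class RandomCropSketch(torch.nn.Module):
    """Hypothetical crop transform illustrating the get_params() pattern.

    Assumes a tensor image of shape [..., H, W] that is at least as large
    as `size`; padding and PIL inputs are intentionally not handled.
    """

    def __init__(self, size: Tuple[int, int]) -> None:
        super().__init__()
        self.size = size

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        # All random sampling happens here: pick the top-left corner of the crop.
        h, w = img.shape[-2:]
        th, tw = output_size
        if h < th or w < tw:
            raise ValueError(f"Crop size {output_size} is larger than image size {(h, w)}")
        i = int(torch.randint(0, h - th + 1, size=(1,)).item())
        j = int(torch.randint(0, w - tw + 1, size=(1,)).item())
        return i, j, th, tw

    def forward(self, img: Tensor) -> Tensor:
        # forward() only consumes the sampled parameters; it never calls torch.rand itself.
        i, j, h, w = self.get_params(img, self.size)
        return img[..., i : i + h, j : j + w]


img = torch.arange(16.0).reshape(1, 4, 4)
crop = RandomCropSketch((2, 3))
out = crop(img)
print(out.shape)  # torch.Size([1, 2, 3])
```

Because the sampling is isolated in `get_params()`, seeding the global generator with `torch.manual_seed` before calling it reproduces exactly the crop that `forward()` would take.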

Unfortunately, many forward() methods call torch.rand directly. Here are a few examples:

```python
def forward(self, img):
    if self.p < torch.rand(1):
        ...
```

Five other call sites use the same inline check:

```python
if torch.rand(1) < self.p:
```
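Applied to the `torch.rand(1) < self.p` pattern, the refactor could look like the sketch below. `RandomHorizontalFlipSketch` and its `get_params(p)` signature are hypothetical illustrations, not code taken from torchvision; the sketch assumes a tensor image whose last dimension is the width.

```python
import torch
from torch import Tensor


class RandomHorizontalFlipSketch(torch.nn.Module):
    """Hypothetical flip transform: the coin toss lives in get_params()."""

    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    @staticmethod
    def get_params(p: float) -> bool:
        # Same test as the inline `torch.rand(1) < self.p`, moved into one place.
        return torch.rand(1).item() < p

    def forward(self, img: Tensor) -> Tensor:
        # forward() just acts on the sampled decision.
        if self.get_params(self.p):
            return img.flip(-1)  # mirror along the last (width) axis
        return img


img = torch.tensor([[1.0, 2.0, 3.0]])
print(RandomHorizontalFlipSketch(p=1.0)(img))  # tensor([[3., 2., 1.]])
```

With `p=1.0` the check always passes, since `torch.rand` samples from [0, 1). Keeping the coin toss in `get_params()` also lets a caller sample the decision once and apply it to several inputs, for example an image and its segmentation mask.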

There may be others as well. We should refactor the codebase so that all of the above calls happen within a static get_params() method. See #3065 (RandomInvert) for an example of how to structure it.
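The first example in the list uses the inverted test `self.p < torch.rand(1)`. Assuming that branch skips the transformation, a refactor can have `get_params()` return whether to apply it, written as `not (p < r)` so the comparison keeps the direction it had inline. `RandomApplySketch`, which applies a list of callables with probability `p`, is a hypothetical wrapper used only for illustration.

```python
from typing import Callable, List

import torch
from torch import Tensor


class RandomApplySketch(torch.nn.Module):
    """Hypothetical wrapper: applies a list of callables with probability p."""

    def __init__(self, transforms: List[Callable[[Tensor], Tensor]], p: float = 0.5) -> None:
        super().__init__()
        self.transforms = transforms
        self.p = p

    @staticmethod
    def get_params(p: float) -> bool:
        # Returns True when the transforms should run. `not (p < r)` mirrors the
        # assumed inline early return `if self.p < torch.rand(1): return img`.
        r = torch.rand(1).item()
        return not (p < r)

    def forward(self, img: Tensor) -> Tensor:
        if not self.get_params(self.p):
            return img
        for t in self.transforms:
            img = t(img)
        return img


always = RandomApplySketch([lambda x: x * 2, lambda x: x + 1], p=1.0)
print(always(torch.ones(2, 2)))
```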

cc @vfdev-5
