Description
The preferred way to structure the transform classes is to place the sampling of random parameters in a static get_params() method. This method should receive every hyperparameter needed for the sampling and return all of the random values the transform requires. forward() should then call it during the transformation. Here is an example of this structure:
vision/torchvision/transforms/transforms.py, lines 530 to 531 at commit 9e71fda
vision/torchvision/transforms/transforms.py, line 589 at commit 9e71fda
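Here is a minimal sketch of that structure. RandomIntensityScale is a hypothetical transform invented for illustration (it is not a torchvision class), and the sketch assumes only that torch is installed. All randomness is drawn inside the static get_params(), and forward() simply consumes the value it returns.

```python
import torch
from torch import nn


class RandomIntensityScale(nn.Module):
    """Hypothetical transform: multiply the input by a random factor in [low, high)."""

    def __init__(self, low: float = 0.5, high: float = 1.5):
        super().__init__()
        self.low = low
        self.high = high

    @staticmethod
    def get_params(low: float, high: float) -> float:
        # All random sampling lives here: hyperparameters come in as arguments,
        # and the sampled value is returned to the caller.
        return torch.empty(1).uniform_(low, high).item()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # forward() draws its randomness only through get_params().
        factor = self.get_params(self.low, self.high)
        return img * factor
```

Because get_params() is static and returns the sampled value, a caller can draw the parameters once and apply the same factor to several related tensors, rather than having each forward() call sample independently.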
Unfortunately, many forward() methods call torch.rand directly. Here are a few examples:
vision/torchvision/transforms/transforms.py, lines 452 to 453 at commit 9e71fda
vision/torchvision/transforms/transforms.py, line 619 at commit 9e71fda
vision/torchvision/transforms/transforms.py, line 649 at commit 9e71fda
vision/torchvision/transforms/transforms.py, line 700 at commit 9e71fda
vision/torchvision/transforms/transforms.py, line 1454 at commit 9e71fda
vision/torchvision/transforms/transforms.py, line 1560 at commit 9e71fda
There may be others as well. We should refactor the codebase so that all of the above calls happen within a static get_params() method. See #3065 (RandomInvert) for an example of how to structure it.
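A sketch of what such a refactor could look like for a probability-gated transform, assuming a hypothetical vertical flip. RandomVFlipBefore and RandomVFlipAfter are illustrative names, not torchvision classes, and this is not necessarily how #3065 implements RandomInvert. The torch.rand call moves from forward() into get_params() unchanged, so the number and order of random draws stays the same.

```python
import torch
from torch import nn


class RandomVFlipBefore(nn.Module):
    """Hypothetical 'before' version: forward() samples with torch.rand directly."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() < self.p:
            return torch.flip(img, dims=[-2])  # flip along the height axis
        return img


class RandomVFlipAfter(nn.Module):
    """Hypothetical 'after' version: sampling moved into a static get_params()."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    @staticmethod
    def get_params(p: float) -> bool:
        # Same single torch.rand draw as the 'before' version.
        return torch.rand(1).item() < p

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if self.get_params(self.p):
            return torch.flip(img, dims=[-2])  # flip along the height axis
        return img
```

Keeping the draw identical means seeded runs should produce the same outputs before and after the refactor, which makes the change straightforward to check against existing tests.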
cc @vfdev-5