Description
The goal of this issue is two-fold:
- Collect user feedback on some specific design decisions regarding transforms V2. It complements [FEEDBACK] Transforms V2 API #6753, which is for organic / general feedback.
- Document upfront which areas of the new transforms may change in the future without a deprecation cycle (the new transforms are still Beta!).
We'll detail each of those topics below. Please share any feedback or suggestions you may have to help us provide the most useful APIs!
Subclass (un)wrapping
All tensor operations on a datapoint currently lose the datapoint class and return a pure tensor instead. We call this mechanism "subclass unwrapping":
```python
import torch
from torchvision import datapoints

img1 = datapoints.Image(torch.rand(3, 224, 224))
img2 = datapoints.Image(torch.rand(3, 224, 224))
img3 = img1 + img2  # img3 is not an Image tensor anymore, it's a pure tensor!
assert isinstance(img3, torch.Tensor) and not isinstance(img3, datapoints.Image)
# The same is true for datapoints.Video, datapoints.BoundingBox, datapoints.Mask, etc.
```
We currently unwrap datapoints for two reasons. First, some of them (e.g. bounding boxes) come with extra meta-data attached to them, like the bbox format, and there is currently no protocol to pass that meta-data down to the output. Second, in some cases it's impossible to know whether the result of the operation is still a valid datapoint: in the example above, can we still consider img3 to be a valid Image?
We acknowledge that this unwrapping behaviour may seem surprising and unexpected in some cases. For example, for datapoints that don't have meta-data (so the first reason doesn't apply), one could argue that it's up to the user to decide whether the datapoint is still valid or not; following that argument, we could always return Images or Videos, since they don't (currently) have any meta-data. We could also think of a way to "force subclass wrapping", e.g. through a context manager like:
```python
with force_subclass_wrapping():
    img3 = img1 + img2
assert isinstance(img3, torch.Tensor) and isinstance(img3, datapoints.Image)
```
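To make the two behaviours concrete, here is a stdlib-only toy model. It is not torch's actual `__torch_function__` machinery: `Tensor`, `Datapoint`, `Image` and the `force_subclass_wrapping()` context manager are all hypothetical stand-ins defined below. By default an operation on a datapoint returns the plain base class; inside the context manager the result is re-wrapped into the caller's class.

```python
import contextlib
import contextvars

# Hypothetical flag, not part of torchvision.
_FORCE_WRAPPING = contextvars.ContextVar("force_wrapping", default=False)


@contextlib.contextmanager
def force_subclass_wrapping():
    """Temporarily make operations on datapoints return the datapoint class."""
    token = _FORCE_WRAPPING.set(True)
    try:
        yield
    finally:
        _FORCE_WRAPPING.reset(token)


class Tensor:
    """Toy stand-in for torch.Tensor: a flat list of floats."""

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        out = Tensor(a + b for a, b in zip(self.data, other.data))
        return self._wrap_output(out)

    def _wrap_output(self, out):
        return out


class Datapoint(Tensor):
    def _wrap_output(self, out):
        # Default: unwrap, i.e. return a plain Tensor.
        # Inside force_subclass_wrapping(): re-wrap into the caller's class.
        if _FORCE_WRAPPING.get():
            return type(self)(out.data)
        return out


class Image(Datapoint):
    """Toy datapoint WITHOUT meta-data, so re-wrapping needs no extra decisions."""


img1, img2 = Image([0.1, 0.2]), Image([0.3, 0.4])
assert type(img1 + img2) is Tensor      # unwrapped by default
with force_subclass_wrapping():
    assert type(img1 + img2) is Image   # forced wrapping
```

Re-wrapping is only this simple because the toy `Image` carries no meta-data. For a bounding box, the wrapper would also have to decide which format (and image size) to attach to the output, which is exactly the missing protocol mentioned above.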
Let us know what you think!
Bounding box clamping
Currently, all transforms that may operate on a bounding box automatically clamp that bounding box to its corresponding image dimensions. Whether the transforms should clamp or not clamp by default is up for discussion. We could also let users choose (by adding a new parameter to all of those transforms?).
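For reference, clamping an XYXY box to its image means bounding each x coordinate by the image width and each y coordinate by the image height. The plain-Python sketch below illustrates this; `translate_boxes` is a toy transform, and its `clamp` keyword is a hypothetical illustration of the opt-out parameter discussed above, not an existing argument.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2), i.e. XYXY format


def clamp_boxes(boxes: List[Box], height: int, width: int) -> List[Box]:
    """Clamp XYXY boxes so that all coordinates lie inside the image."""
    clamped = []
    for x1, y1, x2, y2 in boxes:
        clamped.append((
            min(max(x1, 0), width),
            min(max(y1, 0), height),
            min(max(x2, 0), width),
            min(max(y2, 0), height),
        ))
    return clamped


def translate_boxes(boxes: List[Box], dx: float, dy: float,
                    height: int, width: int, clamp: bool = True) -> List[Box]:
    """Toy geometric transform with a hypothetical `clamp` switch."""
    moved = [(x1 + dx, y1 + dy, x2 + dx, y2 + dy) for x1, y1, x2, y2 in boxes]
    return clamp_boxes(moved, height, width) if clamp else moved


boxes = [(10, 10, 50, 40)]
print(translate_boxes(boxes, dx=30, dy=0, height=48, width=64))
# [(40, 10, 64, 40)]
print(translate_boxes(boxes, dx=30, dy=0, height=48, width=64, clamp=False))
# [(40, 10, 80, 40)]
```

Note that clamping a box that was moved entirely outside the image yields a zero-area box, which is one way degenerate boxes can appear.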
Enforce a single BoundingBox instance in all transforms?
Right now, some transforms allow multiple BoundingBox instances to be present in the input samples, while others will raise an error. We may consider enforcing one unique BoundingBox instance for all transforms in the future.
(Note that a single BoundingBox instance may still contain multiple bounding boxes!)
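Enforcing a single instance could boil down to one check over the flattened input sample before a transform runs. A minimal sketch, assuming a toy `BoundingBox` class and a simple `flatten` helper; `check_single_bbox_instance` is a hypothetical name, and "at most one" (rather than "exactly one") is our assumption so that samples without boxes still work.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class BoundingBox:
    """Toy stand-in: ONE instance can hold MANY boxes."""
    boxes: List[tuple] = field(default_factory=list)
    format: str = "XYXY"


def flatten(sample: Any) -> List[Any]:
    """Flatten nested dicts/lists/tuples into a list of leaves."""
    if isinstance(sample, dict):
        return [leaf for v in sample.values() for leaf in flatten(v)]
    if isinstance(sample, (list, tuple)):
        return [leaf for v in sample for leaf in flatten(v)]
    return [sample]


def check_single_bbox_instance(sample: Any) -> None:
    """Raise if the sample contains more than one BoundingBox instance."""
    n = sum(isinstance(x, BoundingBox) for x in flatten(sample))
    if n > 1:
        raise ValueError(f"Expected at most one BoundingBox instance, got {n}.")


# One instance holding two boxes: accepted.
check_single_bbox_instance({"image": "img", "boxes": BoundingBox([(0, 0, 1, 1), (2, 2, 3, 3)])})
```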
How to handle labels?
Sometimes, bounding boxes become degenerate after a transformation, and we need a way to remove them along with their associated labels.
Labels are tricky because they can refer to different things: an image, a bounding box, or a mask. In a previous design we had a special Label datapoint subclass, but we decided not to release it for now because of the ambiguity of what labels should refer to: if we have a sample like (img, bbox, label), how do we know whether the label is for the image or for the bounding box?
For this reason we have currently decided not to have a Label datapoint class, and instead let labels be pure tensors (or ints) that are passed through all of the transforms. The only transform that can handle labels is SanitizeBoundingBox, which asks users to manually specify which entries in the input correspond to the labels: there's no need to guess anymore, and no ambiguity.
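The key idea is that the user tells the transform where the labels live, so the same keep/drop mask can be applied to boxes and labels. The stdlib-only sketch below loosely mirrors that idea but is not SanitizeBoundingBox's implementation or signature: `sanitize_boxes`, `boxes_key`, `labels_key` and `min_size` are hypothetical names, and it assumes a dict sample with XYXY boxes and one label per box.

```python
from typing import Any, Dict, List, Tuple

Box = Tuple[float, float, float, float]  # XYXY


def sanitize_boxes(sample: Dict[str, Any], boxes_key: str = "boxes",
                   labels_key: str = "labels", min_size: float = 1.0) -> Dict[str, Any]:
    """Drop degenerate boxes AND the labels at the same positions."""
    boxes: List[Box] = sample[boxes_key]
    labels: List[Any] = sample[labels_key]
    if len(boxes) != len(labels):
        raise ValueError("Expected exactly one label per box.")
    # A box is kept only if both its width and height reach min_size.
    keep = [(x2 - x1) >= min_size and (y2 - y1) >= min_size
            for x1, y1, x2, y2 in boxes]
    out = dict(sample)  # shallow copy; all other entries are passed through
    out[boxes_key] = [box for box, k in zip(boxes, keep) if k]
    out[labels_key] = [lab for lab, k in zip(labels, keep) if k]
    return out


sample = {
    "image": "<image>",
    "boxes": [(0, 0, 10, 10), (5, 5, 5, 20), (3, 3, 8, 9)],  # 2nd box has zero width
    "labels": [1, 2, 3],
}
clean = sanitize_boxes(sample, labels_key="labels")
print(clean["boxes"], clean["labels"])
# [(0, 0, 10, 10), (3, 3, 8, 9)] [1, 3]
```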
We're still considering changing this and potentially bringing back a Label subclass (this is related to another point in this issue about pairwise transforms, which may need a Label subclass).
We're also considering the alternative of not having a Label subclass at all, and instead attaching the label as meta-data to the datapoints: e.g. the Image class could have a label meta-data, and so could the BoundingBox class.
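With the meta-data alternative, labels travel with the boxes they describe, so filtering boxes filters their labels automatically and nobody has to point the transform at a label entry. A toy sketch; the `labels` field on this `BoundingBox` and its `filter` method are hypothetical.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

Box = Tuple[float, float, float, float]  # XYXY


@dataclass
class BoundingBox:
    boxes: List[Box]
    format: str = "XYXY"
    labels: Optional[List[int]] = None  # hypothetical per-box label meta-data

    def filter(self, keep: List[bool]) -> "BoundingBox":
        """Keep a subset of boxes; the label meta-data follows automatically."""
        boxes = [b for b, k in zip(self.boxes, keep) if k]
        labels = None
        if self.labels is not None:
            labels = [lab for lab, k in zip(self.labels, keep) if k]
        return BoundingBox(boxes, self.format, labels)


bbox = BoundingBox([(0, 0, 10, 10), (5, 5, 5, 20)], labels=[7, 9])
kept = bbox.filter([True, False])
print(kept.labels)  # [7]
```

The trade-off is that image-level labels (e.g. for classification) would then live on the Image, so a transform touching both kinds of datapoints has to know which label it is updating.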
Your input on the subject would be valuable.
How to smoothly support "pairwise" transforms?
There are a few critically useful transforms that operate on pairs of samples instead of a single sample: CutMix, MixUp, etc. Because of their fundamentally different behaviour, they tend to be (and currently are) implemented as collation functions to be passed to the DataLoader, so they cannot be used like the rest of the transforms, which makes them harder to use.
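For readers unfamiliar with the collation approach: MixUp needs at least two samples, so it is applied to a whole batch, here by mixing each sample with its neighbour in the batch (rolling the batch by one is a cheap way to form pairs). The sketch uses plain lists instead of tensors; `mixup_collate`, `one_hot` and `mix` are hypothetical helpers. The mixing weight is drawn from a Beta(alpha, alpha) distribution as in MixUp, and a fixed `lam` can be passed for reproducibility.

```python
import random
from typing import List, Optional, Sequence, Tuple


def one_hot(label: int, num_classes: int) -> List[float]:
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def mix(a: Sequence[float], b: Sequence[float], lam: float) -> List[float]:
    return [lam * x + (1.0 - lam) * y for x, y in zip(a, b)]


def mixup_collate(batch: List[Tuple[List[float], int]], num_classes: int,
                  alpha: float = 0.2, lam: Optional[float] = None):
    """Mix each sample with the next one in the batch (batch rolled by one).

    Images are flat lists of floats; labels are ints turned into one-hot
    vectors so that they can be mixed with the same weight as the images.
    """
    if lam is None:
        lam = random.betavariate(alpha, alpha)
    images = [img for img, _ in batch]
    labels = [one_hot(lab, num_classes) for _, lab in batch]
    partner_images = images[1:] + images[:1]
    partner_labels = labels[1:] + labels[:1]
    mixed_images = [mix(a, b, lam) for a, b in zip(images, partner_images)]
    mixed_labels = [mix(a, b, lam) for a, b in zip(labels, partner_labels)]
    return mixed_images, mixed_labels


batch = [([1.0, 1.0], 0), ([0.0, 0.0], 1)]
imgs, labels = mixup_collate(batch, num_classes=2, lam=0.75)
print(imgs)    # [[0.75, 0.75], [0.25, 0.25]]
print(labels)  # [[0.75, 0.25], [0.25, 0.75]]
```

Because the extra arguments must be bound, such a function would be handed to the loader as e.g. `DataLoader(dataset, batch_size=..., collate_fn=functools.partial(mixup_collate, num_classes=10))`, which is exactly why it cannot be composed like a regular transform.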
Those transforms also need to transform the labels on top of the input images, and we're still trying to figure out the smoothest way to handle labels (see the point on labels in this issue).
For these reasons we have left them in the prototype area for now, as we're still aiming to polish their APIs.
One option we are discussing (but it is far from finalized) is to implement those transforms as stateful transforms, so that they can be used like regular transforms. Something roughly along those lines:
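```python
class MixUp:
    def __init__(self):
        self._prev_img, self._prev_label = None, None

    def forward(self, img, label):
        # Mix the current sample with the one seen on the previous call
        out = _mixup_pair(img, label, self._prev_img, self._prev_label)
        self._prev_img, self._prev_label = img, label
        return out
```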
Whether this is a good or a terrible idea is still up for discussion!
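A self-contained version of that idea, with a toy `_mixup_pair` on flat lists and one-hot labels, makes one open question visible: the very first sample has no partner. Here it is simply returned unchanged; that choice, the fixed `lam`, and the use of `__call__` instead of an `nn.Module.forward` are assumptions for illustration.

```python
from typing import List, Optional, Tuple

Sample = Tuple[List[float], List[float]]  # (image, one-hot label)


def _mixup_pair(img, label, prev_img, prev_label, lam):
    """Convex combination of two samples; labels are mixed with the same weight."""
    mixed_img = [lam * a + (1.0 - lam) * b for a, b in zip(img, prev_img)]
    mixed_label = [lam * a + (1.0 - lam) * b for a, b in zip(label, prev_label)]
    return mixed_img, mixed_label


class MixUp:
    """Stateful MixUp: each call mixes the input with the previous call's input."""

    def __init__(self, lam: float = 0.5):
        self.lam = lam
        self._prev: Optional[Sample] = None

    def __call__(self, img: List[float], label: List[float]) -> Sample:
        if self._prev is None:
            out = (img, label)  # no partner yet: pass the first sample through
        else:
            out = _mixup_pair(img, label, *self._prev, lam=self.lam)
        self._prev = (img, label)  # remember the *unmixed* sample
        return out


mixup = MixUp(lam=0.5)
print(mixup([1.0, 1.0], [1.0, 0.0]))  # ([1.0, 1.0], [1.0, 0.0])
print(mixup([0.0, 0.0], [0.0, 1.0]))  # ([0.5, 0.5], [0.5, 0.5])
```

With this design the pairing depends on the order in which samples are loaded, and with a multi-worker DataLoader each worker process holds its own copy of the transform and therefore its own state. These are the kinds of trade-offs that make the approach debatable.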
Supporting user-defined datapoints and datapoint methods
Users can already implement custom transforms that are compatible with transforms V2. Implementing a user-defined datapoint is also supported, but we're not too happy with the way we currently enable that support. To enable custom datapoints, we currently expose a lot of the transforms as methods on the datapoint classes, e.g. the Image class has .resize(), .crop(), .rotate(), etc.
This isn't something we're too happy with because it makes implementing new transforms cumbersome, and these methods may also conflict with the Tensor base-class namespace.
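One way to avoid methods on the datapoint classes is a registry mapping (functional, datapoint type) pairs to kernels: supporting a new datapoint then means registering kernels, and nothing is added to the Tensor namespace. The sketch below is hypothetical; `register_kernel`, `resize`, and the toy `Image`/`Heatmap` classes are illustrative names defined here, not the current torchvision API.

```python
from typing import Any, Callable, Dict, Tuple, Type

# (functional, input type) -> kernel implementing that functional for that type
_KERNEL_REGISTRY: Dict[Tuple[Callable, Type], Callable] = {}


def register_kernel(functional: Callable, datapoint_cls: Type):
    """Decorator registering a kernel for `functional` on `datapoint_cls`."""
    def decorator(kernel: Callable) -> Callable:
        _KERNEL_REGISTRY[(functional, datapoint_cls)] = kernel
        return kernel
    return decorator


def resize(inpt: Any, size: Tuple[int, int]) -> Any:
    """Functional entry point: dispatch on the input's type, walking its MRO."""
    for cls in type(inpt).__mro__:
        kernel = _KERNEL_REGISTRY.get((resize, cls))
        if kernel is not None:
            return kernel(inpt, size)
    raise TypeError(f"resize() does not support inputs of type {type(inpt).__name__}")


class Image:
    def __init__(self, height: int, width: int):
        self.height, self.width = height, width


@register_kernel(resize, Image)
def _resize_image(img: Image, size: Tuple[int, int]) -> Image:
    return Image(*size)


# A user-defined datapoint: neither Image nor the Tensor namespace is touched.
class Heatmap:
    def __init__(self, height: int, width: int, scale: float):
        self.height, self.width, self.scale = height, width, scale


@register_kernel(resize, Heatmap)
def _resize_heatmap(hm: Heatmap, size: Tuple[int, int]) -> Heatmap:
    return Heatmap(*size, scale=hm.scale)


print(vars(resize(Heatmap(8, 8, scale=2.0), (4, 4))))
# {'height': 4, 'width': 4, 'scale': 2.0}
```

Walking the MRO means subclasses of a registered datapoint fall back to their parent's kernel unless they register their own.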
We do not guarantee that we'll keep supporting those methods in the future.
Tensor pass-through heuristic
At this time, inputs that aren't datapoints are passed through all transforms:
```python
transformed_img, other_stuff = t(img, other_stuff)
# other_stuff is passed through!
```
Well, not all non-datapoint inputs are passed through: we want transforms V2 to be fully backward compatible with the V1 transforms, so we still need to treat pure tensors as images.
For this reason we currently have implemented a (potentially surprising) heuristic:
- If we find an explicit image or video (Image, Video, or PIL Image) in the input, all other plain tensors are passed through.
- If there is no explicit image or video, only the first plain tensor will be transformed as image or video, while all others will be passed through.
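The two rules above can be written as a function that decides, for each input, whether it gets transformed. A stdlib-only sketch with toy types: `PlainTensor` stands for a pure `torch.Tensor` and `PILImage` for a PIL image, `needs_transform_list` is a hypothetical name, and, as a simplification, everything other than images, videos and plain tensors (including other datapoints such as BoundingBox or Mask) is left out of the picture and treated as pass-through.

```python
from typing import Any, List


class PlainTensor:  # stand-in for a pure torch.Tensor
    pass


class Image(PlainTensor):  # stand-ins for the datapoint subclasses
    pass


class Video(PlainTensor):
    pass


class PILImage:  # stand-in for PIL.Image.Image
    pass


EXPLICIT = (Image, Video, PILImage)


def needs_transform_list(flat_inputs: List[Any]) -> List[bool]:
    """Return, for each input, whether the transform should touch it."""
    has_explicit = any(isinstance(x, EXPLICIT) for x in flat_inputs)
    # Transform a plain tensor only if there is no explicit image/video,
    # and then only the first one (V1 backward compatibility).
    transform_next_plain = not has_explicit
    result = []
    for x in flat_inputs:
        if isinstance(x, EXPLICIT):
            result.append(True)
        elif isinstance(x, PlainTensor):
            result.append(transform_next_plain)
            transform_next_plain = False
        else:
            result.append(False)  # toy simplification: everything else passes through
    return result


print(needs_transform_list([PlainTensor(), PlainTensor()]))  # [True, False]
print(needs_transform_list([PlainTensor(), Image()]))        # [False, True]
```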
If this is confusing, don't worry: 99% of users don't need to care about this anyway.
We're considering ways to simplify or even remove this heuristic, e.g. in #7340.