
[FEEDBACK] TransformsV2: What may change in the future (we need your input!) #7319

Closed
@NicolasHug

The goal of this issue is two-fold:

  • Collect user feedback on some specific design decisions regarding transforms V2. It complements [FEEDBACK] Transforms V2 API #6753, which is for organic / general feedback.
  • Document upfront which areas of the new transforms may change in the future without a deprecation cycle (the new transforms are still Beta!).

We'll detail each of those topics below. Please share any feedback or suggestion you may have to help us provide the most useful APIs!

Subclass (un)wrapping

All tensor operations on a datapoint currently lose the datapoint class and return a pure tensor instead. We call this mechanism "subclass unwrapping":

import torch
from torchvision import datapoints

img1 = datapoints.Image(torch.rand(3, 224, 224))
img2 = datapoints.Image(torch.rand(3, 224, 224))
img3 = img1 + img2  # img3 is not an Image tensor anymore, it's a pure tensor!
assert isinstance(img3, torch.Tensor) and not isinstance(img3, datapoints.Image)

# The same is true for datapoints.Video, datapoints.BoundingBox, datapoints.Mask, etc.

We currently unwrap datapoints for two reasons. First, some of them (e.g. bounding boxes) come with extra meta-data attached, like the bbox format, and there is currently no protocol to pass that meta-data down to the output. Second, in some cases it's impossible to know whether the result of the operation is still a valid datapoint: in the example above, can we still consider img3 to be a valid Image?

We acknowledge that this unwrapping behaviour may seem surprising and unexpected in some cases. For example, for datapoints that don't have meta-data (so the first reason doesn't apply), one could argue that it's up to the user to decide whether the datapoint is still valid; following that argument, we could always return Images or Videos, since they don't (currently) have any meta-data. We could also think of a way to "force subclass wrapping", e.g. through a context manager like:

with force_subclass_wrapping():
    img3 = img1 + img2
    assert isinstance(img3, torch.Tensor) and isinstance(img3, datapoints.Image)

Let us know what you think!

Bounding box clamping

Currently, all transforms that may operate on a bounding box automatically clamp it to the dimensions of its corresponding image. Whether transforms should clamp by default is up for discussion. We could also let users choose (perhaps by adding a new parameter to all of those transforms?).
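A minimal sketch of what clamping does, and of the opt-out parameter floated above. It assumes boxes in XYXY format (absolute pixel coordinates) and a `(height, width)` canvas size; `translate_boxes` and its `clamp` flag are hypothetical names, not existing torchvision APIs, and plain tuples stand in for tensors.

```python
def clamp_boxes_xyxy(boxes, canvas_size):
    """Clamp XYXY boxes so that x lies in [0, width] and y in [0, height]."""
    height, width = canvas_size
    return [
        (min(max(x1, 0), width), min(max(y1, 0), height),
         min(max(x2, 0), width), min(max(y2, 0), height))
        for x1, y1, x2, y2 in boxes
    ]


def translate_boxes(boxes, dx, dy, canvas_size, clamp=True):
    """Hypothetical box transform with an opt-out `clamp` parameter."""
    moved = [(x1 + dx, y1 + dy, x2 + dx, y2 + dy) for x1, y1, x2, y2 in boxes]
    return clamp_boxes_xyxy(moved, canvas_size) if clamp else moved


boxes = [(10, 20, 100, 200)]
print(translate_boxes(boxes, dx=150, dy=0, canvas_size=(224, 224)))
# [(160, 20, 224, 200)]
print(translate_boxes(boxes, dx=150, dy=0, canvas_size=(224, 224), clamp=False))
# [(160, 20, 250, 200)]
```

Without clamping, the box extends past the right edge of the image; with clamping it is cut at x = 224, which silently changes its area. That trade-off is the crux of the default-behaviour question.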

Enforce a single BoundingBox instance in all transforms?

Right now, some transforms allow for multiple BoundingBox instances to be present in the input samples, while others will raise an error. We may consider enforcing one unique BoundingBox instance for all transforms in the future.

(Note that a single BoundingBox instance may still contain multiple bounding boxes!)
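One way such a rule could be enforced is a check over the flattened input sample, sketched below in plain Python. `BoundingBox` here is a toy stand-in and `get_single_bbox_instance` is a hypothetical helper; whether an absent BoundingBox should also be an error is a separate design choice (this sketch allows it).

```python
class BoundingBox:
    """Toy datapoint: a single instance may hold several boxes."""

    def __init__(self, boxes):
        self.boxes = list(boxes)


def get_single_bbox_instance(flat_inputs):
    """Return the unique BoundingBox in a flattened sample (None if absent).

    Raises if the sample contains more than one BoundingBox instance.
    """
    found = [inpt for inpt in flat_inputs if isinstance(inpt, BoundingBox)]
    if len(found) > 1:
        raise ValueError(f"Expected at most one BoundingBox instance, got {len(found)}")
    return found[0] if found else None


sample = ["image", BoundingBox([(0, 0, 10, 10), (5, 5, 20, 20)]), [3, 7]]
bbox = get_single_bbox_instance(sample)
print(len(bbox.boxes))  # 2: two boxes, but a single instance
```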

How to handle labels?

Sometimes, bounding boxes become degenerate after a transformation (e.g. they end up with zero width or height), and we need a way to remove them along with their associated labels.

Labels are tricky because they can refer to different things: an image, a bounding box, or a mask. In a previous design we had a special Label datapoint subclass, but we decided not to release it for now because of the ambiguity of what it should refer to: if we have a sample like img, bbox, label, how do we know whether the label is for the image or for the bounding box?

For this reason, we have currently decided not to have a Label datapoint class, and instead let labels be pure tensors (or ints) that are passed through all of the transforms. The only transform that can handle labels is SanitizeBoundingBox, which asks users to specify manually which entries in the input correspond to the labels: there's no need to guess anymore, and no ambiguity.
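The idea behind SanitizeBoundingBox can be sketched in a few lines of plain Python. This is not the actual implementation: the dict keys `"boxes"` and `"labels"`, the `labels_key` argument and the `min_size` threshold are assumptions made for illustration, and boxes are assumed to be XYXY tuples.

```python
def sanitize_bounding_boxes(sample, labels_key="labels", min_size=1):
    """Drop degenerate XYXY boxes and the labels at the same positions.

    `labels_key` plays the role of the user telling the transform which
    entry of the sample holds the labels, so nothing has to be guessed.
    """
    boxes, labels = sample["boxes"], sample[labels_key]
    keep = [
        (x2 - x1) >= min_size and (y2 - y1) >= min_size
        for x1, y1, x2, y2 in boxes
    ]
    out = dict(sample)  # other entries (e.g. the image) are passed through
    out["boxes"] = [box for box, k in zip(boxes, keep) if k]
    out[labels_key] = [lab for lab, k in zip(labels, keep) if k]
    return out


sample = {
    "image": "img",
    "boxes": [(0, 0, 10, 10), (5, 5, 5, 30), (2, 3, 8, 9)],
    "labels": [1, 2, 3],
}
print(sanitize_bounding_boxes(sample))
# {'image': 'img', 'boxes': [(0, 0, 10, 10), (2, 3, 8, 9)], 'labels': [1, 3]}
```

The second box has zero width, so it is removed together with label 2; the user-supplied key is what ties the two lists together.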

We're still considering changing this and potentially bringing back a Label subclass (this is related to another point in this issue about pairwise transforms, which may need a Label subclass).

We're also considering the alternative of not having a Label subclass, and instead letting the label be meta-data attached to the datapoints: e.g. the Image class could have label meta-data, and so could the BoundingBox class.
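Under this alternative, filtering boxes would no longer require the user to point at the labels, because each BoundingBox would carry its own. A toy sketch of that hypothetical design; the class, its `select` method and the constructor arguments are illustrative assumptions:

```python
class BoundingBox:
    """Toy datapoint carrying per-box labels as meta-data (hypothetical design)."""

    def __init__(self, boxes, labels, box_format="XYXY"):
        if len(boxes) != len(labels):
            raise ValueError("need exactly one label per box")
        self.boxes, self.labels = list(boxes), list(labels)
        self.box_format = box_format

    def select(self, keep):
        """Keep the boxes where `keep` is True; labels follow automatically."""
        return BoundingBox(
            [b for b, k in zip(self.boxes, keep) if k],
            [lab for lab, k in zip(self.labels, keep) if k],
            box_format=self.box_format,  # meta-data is propagated explicitly
        )


bbox = BoundingBox([(0, 0, 10, 10), (5, 5, 5, 30)], labels=["cat", "dog"])
kept = bbox.select([True, False])
print(kept.boxes, kept.labels)  # [(0, 0, 10, 10)] ['cat']
```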

Your input on the subject would be valuable.

How to smoothly support "pairwise" transforms?

There are a few critically useful transforms that operate on pairs of samples instead of a single sample: CutMix, MixUp, etc. Because of their fundamentally different behaviour, they tend to be (and currently are) implemented as collation functions to be passed to the DataLoader, so they cannot be used like the rest of the transforms, which makes them harder to use.
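For context, a collate-style pairwise transform looks roughly like the sketch below: it only works on a whole batch, which is why it has to be handed to the DataLoader (e.g. through its `collate_fn` argument) rather than composed with per-sample transforms. This is a plain-Python illustration, not torchvision's implementation; it assumes one-hot labels and mixes each sample with its predecessor in the batch (wrapping around), using the usual MixUp formula x = lam * x1 + (1 - lam) * x2 with lam drawn from Beta(alpha, alpha).

```python
import random


def mixup_collate(batch, alpha=0.2, lam=None):
    """Mix each sample with its neighbour in the batch (a "pairwise" transform).

    `batch` is a list of (features, one_hot_label) pairs of equal-length lists.
    `lam` can be fixed for reproducibility; otherwise it is drawn from Beta(alpha, alpha).
    """
    if lam is None:
        lam = random.betavariate(alpha, alpha)
    mixed = []
    for i, (x, y) in enumerate(batch):
        x_other, y_other = batch[i - 1]  # for i == 0 this wraps to the last sample
        mixed.append((
            [lam * a + (1 - lam) * b for a, b in zip(x, x_other)],
            [lam * a + (1 - lam) * b for a, b in zip(y, y_other)],
        ))
    return mixed


batch = [([0.0], [1.0, 0.0, 0.0]), ([2.0], [0.0, 1.0, 0.0]), ([4.0], [0.0, 0.0, 1.0])]
print(mixup_collate(batch, lam=0.5))
# [([2.0], [0.5, 0.0, 0.5]), ([1.0], [0.5, 0.5, 0.0]), ([3.0], [0.0, 0.5, 0.5])]
```

Real collate functions would also stack the mixed samples into batch tensors; that step is omitted here.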

Those transforms also need to transform the labels on top of the input images, and we're still trying to figure out the smoothest way to handle labels (see the other point in this issue).

For these reasons we have currently left those in the prototype area as we're still aiming to polish their APIs.

One option we are discussing (but it is far from finalized) is to implement those transforms as stateful transforms, so that they can be used like regular transforms. Something roughly along those lines:

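class MixUp(torch.nn.Module):
    _prev_img = _prev_label = None  # sample seen on the previous call

    def forward(self, img, label):
        out = (img, label)  # nothing to mix with on the very first call
        if self._prev_img is not None:  # _mixup_pair: placeholder for the mixing logic
            out = _mixup_pair(img, label, self._prev_img, self._prev_label)
        self._prev_img, self._prev_label = img, label
        return out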

Whether this is a good or a terrible idea is still up for discussion!
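A dependency-free sketch of the stateful approach, with the pair-mixing step filled in using the standard MixUp formula on one-hot labels. Passing the very first sample through unchanged is one possible policy among several, and the class and its parameters are hypothetical; plain lists stand in for tensors.

```python
import random


class MixUp:
    """Stateful per-sample MixUp: mixes each sample with the previous one."""

    def __init__(self, alpha=0.2, lam=None):
        self.alpha, self.lam = alpha, lam  # a fixed `lam` is handy for testing
        self._prev = None  # (features, one_hot_label) from the previous call

    def __call__(self, x, y):
        prev, self._prev = self._prev, (x, y)
        if prev is None:
            return x, y  # first call: no partner yet, pass the sample through
        lam = self.lam if self.lam is not None else random.betavariate(self.alpha, self.alpha)
        x_prev, y_prev = prev
        return (
            [lam * a + (1 - lam) * b for a, b in zip(x, x_prev)],
            [lam * a + (1 - lam) * b for a, b in zip(y, y_prev)],
        )


mixup = MixUp(lam=0.5)
print(mixup([0.0], [1.0, 0.0]))  # ([0.0], [1.0, 0.0]), passed through
print(mixup([2.0], [0.0, 1.0]))  # ([1.0], [0.5, 0.5])
```

The sketch makes two open questions visible: the output for a sample depends on whichever sample happened to come before it, and with a multi-worker DataLoader each worker process would hold its own copy of that state.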

Supporting user-defined datapoints and datapoints methods

Users can already implement custom transforms that are compatible with transforms V2. Implementing a user-defined datapoint is also supported, but we're not too happy with the way we currently enable that support. To enable custom datapoints, we currently expose many of the transforms as methods on the datapoint classes, e.g. the Image class has .resize(), .crop(), .rotate(), etc.

This isn't something we're too happy with because it makes the implementation of new transforms cumbersome, and it may also conflict with the Tensor base-class namespace.

We do not guarantee that we'll keep supporting those methods in the future.
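A toy sketch of the method-based mechanism described above: the functional transform calls a method such as `.resize()` on the datapoint, so a user-defined datapoint plugs in by implementing that method. `Datapoint`, `Image`, `Keypoints` and the functional `resize` are simplified, hypothetical stand-ins, not the torchvision classes.

```python
class Datapoint:
    """Toy base class: each datapoint knows how to resize itself."""

    def resize(self, size):
        raise NotImplementedError


class Image(Datapoint):
    def __init__(self, height, width):
        self.height, self.width = height, width

    def resize(self, size):
        return Image(*size)


class Keypoints(Datapoint):
    """User-defined datapoint: must implement every transform method it supports."""

    def __init__(self, points, canvas_size):
        self.points, self.canvas_size = list(points), canvas_size

    def resize(self, size):
        old_h, old_w = self.canvas_size
        new_h, new_w = size
        scaled = [(x * new_w / old_w, y * new_h / old_h) for x, y in self.points]
        return Keypoints(scaled, size)


def resize(inpt, size):
    """Functional entry point: dispatches to the datapoint's own method."""
    if isinstance(inpt, Datapoint):
        return inpt.resize(size)
    return inpt  # non-datapoints are passed through


kp = resize(Keypoints([(50, 100)], canvas_size=(200, 100)), size=(400, 300))
print(kp.points)  # [(150.0, 200.0)]
```

Both drawbacks mentioned above show up here: every new transform needs a matching method on every datapoint class, and names like `resize` live in the same namespace as methods that torch.Tensor already defines (it has its own `resize_`, for instance).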

Tensor pass-through heuristic

At this time, inputs that aren't datapoints are passed through all transforms:

transformed_img, other_stuff = t(img, other_stuff)
# other_stuff is passed-through!

Well, not all non-datapoint inputs are passed through: we still want transforms V2 to be fully backward compatible with the V1 transforms, so pure tensors must still be treated as images.

For this reason we currently have implemented a (potentially surprising) heuristic:

  • If we find an explicit image or video (Image, Video, or PIL Image) in the input, all other plain tensors are passed through.
  • If there is no explicit image or video, only the first plain tensor will be transformed as image or video, while all others will be passed through.
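The two rules above can be written as a small decision function. This plain-Python sketch mirrors only those two bullets (plus "datapoints are transformed, everything else passes through"); the classes are stand-ins and the actual torchvision logic may differ in details.

```python
class PlainTensor:
    """Stand-in for a plain torch.Tensor."""


class Datapoint(PlainTensor):
    """Stand-in for the datapoint base class (a tensor subclass)."""


class Image(Datapoint): ...
class Video(Datapoint): ...
class BoundingBox(Datapoint): ...
class PILImage: ...  # stand-in for PIL.Image.Image


def needs_transform(flat_inputs):
    """Return one bool per input: True if transformed, False if passed through."""
    has_explicit_image = any(isinstance(i, (Image, Video, PILImage)) for i in flat_inputs)
    first_plain_seen = False
    decisions = []
    for inpt in flat_inputs:
        if isinstance(inpt, (Datapoint, PILImage)):
            decisions.append(True)  # datapoints and PIL images are always candidates
        elif isinstance(inpt, PlainTensor):
            # A plain tensor is treated as an image only if no explicit image/video
            # is present and it is the first plain tensor encountered.
            decisions.append(not has_explicit_image and not first_plain_seen)
            first_plain_seen = True
        else:
            decisions.append(False)  # anything else (ints, strings, ...) passes through
    return decisions


print(needs_transform([PlainTensor(), PlainTensor(), 3]))  # [True, False, False]
print(needs_transform([PlainTensor(), Image(), BoundingBox()]))  # [False, True, True]
```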

If this is confusing, don't worry: 99% of users won't need to think about it.

We're considering ways to simplify or even remove this heuristic, e.g. in #7340.
