
[RFC] Loss Functions in Torchvision #2980

Open
@oke-aditya

Description


🚀 Feature

A loss functions API in torchvision.

Motivation

The request is simple: torchvision already ships loss functions, e.g. sigmoid_focal_loss and l1_loss, but they are scattered and have to be used as torchvision.ops.sigmoid_focal_loss, etc.

In the future we might need to include further loss functions, e.g. dice_loss.
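For context, the focal loss currently lives under the ops namespace and is used like this (shapes here are purely illustrative):

```python
import torch
from torchvision.ops import sigmoid_focal_loss

# Illustrative shapes: logits for 8 samples x 4 classes, with binary targets.
inputs = torch.randn(8, 4)
targets = torch.randint(0, 2, (8, 4)).float()

# Today the loss is reached through torchvision.ops rather than a losses namespace.
loss = sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction="mean")
print(loss)
```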

Since loss functions are differentiable, we can put them under nn. We could then have

torchvision.nn.losses.sigmoid_focal_loss and so on.

This keeps the scope of nn open for other differentiable components, such as layers.
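As an illustration only (this is not existing torchvision code), the proposed torchvision.nn.losses module could start by re-exporting the current ops implementations and then host new vision-specific losses such as a dice loss:

```python
# Hypothetical torchvision/nn/losses.py -- a sketch of the proposed namespace.
import torch

from torchvision.ops import sigmoid_focal_loss  # re-exported from ops  # noqa: F401


def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Minimal soft Dice loss sketch for batched binary masks (illustrative only)."""
    probs = inputs.sigmoid().flatten(1)
    targets = targets.flatten(1)
    intersection = (probs * targets).sum(dim=1)
    union = probs.sum(dim=1) + targets.sum(dim=1)
    return (1.0 - (2.0 * intersection + eps) / (union + eps)).mean()
```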

Pitch

These losses are very specific to the vision domain. They are really useful and, in general, not tied to any specific model.
Loss functions usually live in torch, though. If we keep them under an nn namespace in torchvision, a future migration stays simple:

instead of torchvision.nn.sigmoid_focal_loss it would become torch.nn.sigmoid_focal_loss.
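If such a migration ever happened, torchvision could keep old imports working with a thin deprecation alias; a rough sketch (the torch.nn target below is hypothetical, not an existing function):

```python
import warnings

import torch


def sigmoid_focal_loss(*args, **kwargs):
    # Hypothetical shim: torch.nn.sigmoid_focal_loss does not exist today; it stands
    # in for wherever the upstream implementation would land after a migration.
    warnings.warn(
        "torchvision.nn.losses.sigmoid_focal_loss has moved to torch.nn; "
        "please update your imports.",
        DeprecationWarning,
        stacklevel=2,
    )
    return torch.nn.sigmoid_focal_loss(*args, **kwargs)
```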

This pitch follows from the issues linked above.
More Loss Functions

Alternatives

Alternatively, these could go directly into torch. But with the approach above, we can support them in torchvision first and later deprecate and move them to torch when needed.

Currently we include them under ops, but a differentiable loss function is not really an operation.

The other ops are not differentiable in the same sense; they perform transformations or other manipulations over boxes/layers.

Additional context

Here is a list of loss functions we would like to include.

References

We can refer to Kornia, fvcore, and a few PyTorch issues that request this feature.
