Multi-class focal loss #3250

Open
@addisonklinke

🚀 Feature

Define an official multi-class focal loss function

Motivation

Most object detectors handle more than one class, so a multi-class focal loss function would cover more use cases than the existing binary focal loss released in v0.8.0.

Additionally, many different implementations of multi-class focal loss are floating around on the web (PyTorch forums, GitHub, etc.). As the authors of the RetinaNet paper, Facebook AI Research should provide a definitive version to settle the existing debates.

Pitch

To the best of my understanding, this version by Thomas V. on the PyTorch forums is correct. Please correct me if this is not the right approach:

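import torch
import torch.nn.functional as F

# Dummy batch: raw logits and integer class labels
batch_size = 8
num_classes = 5
logits = torch.randn(batch_size, num_classes)
targets = torch.randint(0, num_classes, (batch_size, ))

alpha = 0.25  # scalar weight applied to every class
gamma = 2     # focusing parameter
# Per-sample cross-entropy, i.e. -log(pt) for the true class
ce_loss = F.cross_entropy(logits, targets, reduction='none')
# Recover pt, the predicted probability of the true class
pt = torch.exp(-ce_loss)
# Per-sample focal loss (no reduction applied)
focal_loss = alpha * (1 - pt) ** gamma * ce_loss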
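
The snippet above relies on exp(-cross_entropy) being the softmax probability of the true class. Here is a minimal sketch of that identity for a single sample, written in plain Python so the arithmetic is visible. The helper names (softmax, focal_loss_single) and the example logits are hypothetical and chosen only for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def focal_loss_single(logits, target, alpha=0.25, gamma=2.0):
    """Focal loss for one sample, mirroring the tensor snippet above.

    Returns (loss, pt, p_true) so the identity pt == p_true can be checked.
    """
    probs = softmax(logits)
    p_true = probs[target]          # softmax probability of the true class
    ce = -math.log(p_true)          # cross-entropy for this sample
    pt = math.exp(-ce)              # same recovery step as torch.exp(-ce_loss)
    loss = alpha * (1.0 - pt) ** gamma * ce
    return loss, pt, p_true

# Illustrative logits: class 0 is predicted confidently, class 2 is not
example_logits = [2.0, 0.5, -1.0]
easy_loss, easy_pt, easy_p = focal_loss_single(example_logits, target=0)
hard_loss, hard_pt, hard_p = focal_loss_single(example_logits, target=2)
```

With these inputs the well-classified sample (target 0) receives a much smaller loss than the misclassified one (target 2), which is the down-weighting behaviour the (1 - pt) ** gamma factor is meant to provide.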

Alternatives

Individual practitioners continue writing their own implementations.

Additional context

The RetinaNet paper doesn't provide any equations describing multi-class focal loss, which I think is partly why implementations currently vary. In particular, alpha_t is not defined for the multi-class case, and I noticed that Thomas and other users don't follow the same alpha -> alpha_t conversion used in torchvision's current binary focal loss implementation.
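
To make the difference concrete, here is a minimal sketch of the two conventions. The binary conversion follows my reading of torchvision's sigmoid_focal_loss (alpha for positives, 1 - alpha for negatives). The multi-class variant is purely an assumption: it looks up a per-class weight by target index, which is one possible generalization and not something defined by the paper or by torchvision. Both function names are hypothetical.

```python
def alpha_t_binary(alpha, target):
    """Binary conversion: alpha for a positive target (1), 1 - alpha for a negative (0)."""
    return alpha * target + (1.0 - alpha) * (1 - target)

def alpha_t_multiclass(class_alphas, target):
    """ASSUMPTION: one possible multi-class analogue, a per-class weight lookup.

    class_alphas is a list with one weight per class; target is the class index.
    """
    return class_alphas[target]

# Binary case with alpha = 0.25
pos_weight = alpha_t_binary(0.25, 1)
neg_weight = alpha_t_binary(0.25, 0)

# Hypothetical per-class weights, e.g. treating class 0 as background
class_alphas = [0.75, 0.25, 0.25]
background_weight = alpha_t_multiclass(class_alphas, 0)
foreground_weight = alpha_t_multiclass(class_alphas, 2)
```

By contrast, the forum version in the Pitch multiplies every sample by the same scalar alpha, so it rescales the loss uniformly and does not rebalance classes.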
