Description
🚀 Feature
Define an official multi-class focal loss function
Motivation
Most object detectors handle more than one class, so a multi-class focal loss function would cover more use cases than the existing binary focal loss released in v0.8.0.
Additionally, there are many different implementations of multi-class focal loss floating around on the web (PyTorch forums, GitHub, etc.). As the authors of the RetinaNet paper, Facebook AI Research should provide a definitive version to settle any existing debates.
Pitch
To the best of my understanding, this version by Thomas V. on the PyTorch forums seems correct. Please feel free to correct me if this is not the right approach:
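import torch
import torch.nn.functional as F

batch_size = 8
num_classes = 5
logits = torch.randn(batch_size, num_classes)  # raw, unnormalized scores
targets = torch.randint(0, num_classes, (batch_size, ))  # integer class indices
alpha = 0.25
gamma = 2

# Per-sample cross-entropy, i.e. -log(pt) for the true class
ce_loss = F.cross_entropy(logits, targets, reduction='none')
# Recover pt, the softmax probability of the true class
pt = torch.exp(-ce_loss)
# Down-weight easy examples (pt close to 1); one loss value per sample
focal_loss = alpha * (1 - pt) ** gamma * ce_loss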
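For discussion, the snippet above could be wrapped into a function along these lines. This is only a sketch: the name `multiclass_focal_loss` and the `reduction` argument are my own assumptions, not an existing torchvision API. It also shows a quick sanity check, since with `alpha=1` and `gamma=0` the result should reduce to plain cross-entropy.

```python
import torch
import torch.nn.functional as F


def multiclass_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean"):
    """Hypothetical helper: focal loss over softmax probabilities.

    logits:  (N, C) float tensor of unnormalized scores
    targets: (N,) int64 tensor of class indices in [0, C)
    """
    # Per-sample cross-entropy equals -log(pt), where pt is the softmax
    # probability assigned to the true class.
    ce_loss = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce_loss)
    loss = alpha * (1 - pt) ** gamma * ce_loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # "none": one value per sample


torch.manual_seed(0)
logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))
per_sample = multiclass_focal_loss(logits, targets, reduction="none")
# With alpha=1 and gamma=0 the modulating factor is 1, so the result
# reduces to plain cross-entropy.
plain = multiclass_focal_loss(logits, targets, alpha=1.0, gamma=0.0, reduction="none")
```

Note that `alpha` here is a single scalar applied to every sample, which is exactly the point of contention described under "Additional context".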
Alternatives
Individual practitioners continue writing their own implementations.
Additional context
The RetinaNet paper doesn't provide any equations to describe multi-class focal loss, so I think that's partially why people currently have varying implementations. In particular, alpha_t is not defined for the multi-class case, so I noticed that Thomas and other users don't follow the same alpha --> alpha_t conversion used in torchvision's current binary focal loss implementation.
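To make the difference concrete, here is a small sketch of the alpha --> alpha_t conversion. The binary part follows my understanding of what torchvision's binary focal loss does (positives weighted by alpha, negatives by 1 - alpha). The multi-class part is purely an assumption on my side, one possible generalization using a hypothetical per-class weight vector `class_alpha` with made-up values; the paper does not specify it.

```python
import torch

alpha = 0.25

# Binary case: targets are 0/1 floats, one entry per prediction.
binary_targets = torch.tensor([1.0, 0.0, 0.0, 1.0])
# Positives are weighted by alpha, negatives by (1 - alpha).
alpha_t_binary = alpha * binary_targets + (1 - alpha) * (1 - binary_targets)

# Multi-class case (assumption, not from the paper): a hypothetical
# per-class weight vector, indexed by the integer target of each sample.
class_alpha = torch.tensor([0.75, 0.25, 0.25, 0.25, 0.25])  # illustrative values
targets = torch.tensor([0, 3, 1, 0])
alpha_t_multi = class_alpha[targets]
```

The scalar-alpha snippets posted on the forums correspond to using the same weight for every class, which is why they don't match the binary behavior.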