fix _multiclass_stat_scores_update in classification #3078

Draft
wants to merge 2 commits into
base: master

@ved1beta ved1beta commented Apr 30, 2025

What does this PR do?

Fixes #3068

This PR fixes an issue with the multiclass accuracy calculation when using top_k > 1 with average="micro". The bug caused incorrect accuracy calculations in scenarios where predictions were provided as logits/probabilities and the correct class needed to be identified among the top-k predictions.

The fix ensures that:

  1. The one-hot encoding path is always used when top_k > 1, regardless of the averaging method
  2. A special case is added in multiclass_accuracy to properly handle top-k with micro averaging
  3. Top-k selection is consistently applied across all evaluation scenarios

The PR includes test cases that demonstrate the issue and verify the fix works correctly. These tests show the alignment between manual calculations of top-k accuracy and the results from the metric implementation.
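
As an illustration of the kind of manual calculation these tests compare against, here is a minimal plain-Python sketch. It assumes the conventional definition of top-k accuracy (a sample counts as correct when its target class is among the k highest-scoring classes) and mimics the one-hot path by building a multi-hot prediction row. The helper names `topk_multi_hot` and `micro_topk_accuracy` and the sample data are hypothetical; this is not the torchmetrics implementation.

```python
def topk_multi_hot(scores, k):
    """Return a 0/1 row with ones at the k highest-scoring class indices."""
    # sorted() is stable, so ties go to the lower class index
    # (an arbitrary choice made for this sketch).
    top = sorted(range(len(scores)), key=lambda c: -scores[c])[:k]
    return [1 if c in top else 0 for c in range(len(scores))]


def micro_topk_accuracy(probs, targets, k):
    """Fraction of samples whose target class is among the top-k predictions."""
    hits = 0
    for row, target in zip(probs, targets):
        multi_hot = topk_multi_hot(row, k)
        hits += multi_hot[target]  # 1 if the true class was selected, else 0
    return hits / len(targets)


# Hypothetical data: 3 samples, 3 classes.
probs = [
    [0.1, 0.6, 0.3],  # target 2 is ranked second
    [0.7, 0.2, 0.1],  # target 0 is ranked first
    [0.2, 0.5, 0.3],  # target 0 is ranked third
]
targets = [2, 0, 0]

acc_by_k = {k: micro_topk_accuracy(probs, targets, k) for k in (1, 2, 3)}
```

With this data the micro accuracy grows with k (1/3, 2/3, 1), which is the behaviour a top-k metric is expected to show when targets sit at different ranks.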

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not needed for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--3078.org.readthedocs.build/en/3078/

fixes #3068

Contributor

@rittik9 rittik9 left a comment

Hi @ved1beta, thanks for opening this PR,
but I think these issues are still present even with the updated code:

>>> logits
tensor([[[0.0000, 0.1000, 0.5000, 0.4000],
         [0.0000, 0.2000, 0.7000, 0.1000]],

        [[0.0000, 0.4000, 0.3000, 0.3000],
         [1.0000, 0.0000, 0.0000, 0.0000]]])
>>> code
tensor([[3, 2],
        [1, 0]])
>>> logits.shape
torch.Size([2, 2, 4])
>>> code.shape
torch.Size([2, 2])
>>> acc = Accuracy(task="multiclass", ignore_index=0, num_classes=4, multidim_average="global", average="micro", top_k=4)
>>> acc(logits.transpose(2, 1), code)
tensor(0.6667)
>>> acc = Accuracy(task="multiclass", ignore_index=0, num_classes=4, multidim_average="global", average="micro", top_k=3)
>>> acc(logits.transpose(2, 1), code)
tensor(0.6667)
>>> acc = Accuracy(task="multiclass", ignore_index=0, num_classes=4, multidim_average="global", average="micro", top_k=2)
>>> acc(logits.transpose(2, 1), code)
tensor(0.6667)
>>> acc = Accuracy(task="multiclass", ignore_index=0, num_classes=4, multidim_average="global", average="micro", top_k=1)
>>> acc(logits.transpose(2, 1), code)
tensor(0.6667)

Can you please add these as unit tests and recheck your implementation?
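
The comment does not state which values it expects. As a rough cross-check, the sketch below recomputes the example by hand in plain Python. It assumes the conventional top-k definition and assumes that `ignore_index=0` simply drops every position whose target is 0. The helper names `topk_hit` and `reference_accuracy` are hypothetical and this is not torchmetrics code. The logits are kept in the printed `(N, extra, C)` layout, i.e. before the `transpose(2, 1)` call.

```python
def topk_hit(scores, target, k):
    """True if `target` is among the k highest-scoring class indices."""
    # Stable sort: ties are broken in favour of the lower class index.
    ranked = sorted(range(len(scores)), key=lambda c: -scores[c])
    return target in ranked[:k]


def reference_accuracy(logits, targets, k, ignore_index=None):
    """Micro top-k accuracy over all positions, skipping ignored targets."""
    hits = total = 0
    for sample_logits, sample_targets in zip(logits, targets):
        for scores, target in zip(sample_logits, sample_targets):
            if target == ignore_index:
                continue  # assumed semantics: ignored targets are dropped
            total += 1
            hits += topk_hit(scores, target, k)
    return hits / total


# Values copied from the REPL session above.
logits = [
    [[0.0, 0.1, 0.5, 0.4], [0.0, 0.2, 0.7, 0.1]],
    [[0.0, 0.4, 0.3, 0.3], [1.0, 0.0, 0.0, 0.0]],
]
code = [[3, 2], [1, 0]]

expected = {k: reference_accuracy(logits, code, k, ignore_index=0)
            for k in (1, 2, 3, 4)}
```

Under these assumptions three positions remain after ignoring target 0. Two of them are correct at `top_k=1` (2/3), and the third (target 3, ranked second) becomes correct from `top_k=2` onward, giving 1.0 for k = 2, 3, 4. The session above reports 0.6667 for every k, which is consistent with top-k not being applied.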

@Borda Borda changed the title from "fix _multiclass_stat_scores_update" to "fix _multiclass_stat_scores_update in classification" May 20, 2025
@Borda Borda marked this pull request as draft June 10, 2025 10:24
Development

Successfully merging this pull request may close these issues.

Multiclass accuracy with micro and top-k does not work as expected.
3 participants