fix _multiclass_stat_scores_update in classification #3078

Draft · wants to merge 2 commits into base: master

63 changes: 56 additions & 7 deletions src/torchmetrics/functional/classification/accuracy.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Optional
 
+import torch
 from torch import Tensor
 from typing_extensions import Literal
 
@@ -75,16 +76,26 @@ def _accuracy_reduce(
     """
     if average == "binary":
         return _safe_divide(tp + tn, tp + tn + fp + fn)
 
+    # Calculate base score
+    score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn)
+
+    # For top_k > 1, always use the adjust_weights function which properly handles top_k
+    if top_k > 1:
+        return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k)
+
+    # For top_k=1, continue with the original logic
     if average == "micro":
-        tp = tp.sum(dim=0 if multidim_average == "global" else 1)
-        fn = fn.sum(dim=0 if multidim_average == "global" else 1)
+        # Apply sum before returning for micro averaging
+        tp_sum = tp.sum(dim=0 if multidim_average == "global" else 1)
+        fn_sum = fn.sum(dim=0 if multidim_average == "global" else 1)
         if multilabel:
-            fp = fp.sum(dim=0 if multidim_average == "global" else 1)
-            tn = tn.sum(dim=0 if multidim_average == "global" else 1)
-            return _safe_divide(tp + tn, tp + tn + fp + fn)
-        return _safe_divide(tp, tp + fn)
+            fp_sum = fp.sum(dim=0 if multidim_average == "global" else 1)
+            tn_sum = tn.sum(dim=0 if multidim_average == "global" else 1)
+            return _safe_divide(tp_sum + tn_sum, tp_sum + tn_sum + fp_sum + fn_sum)
+        return _safe_divide(tp_sum, tp_sum + fn_sum)
 
-    score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn)
+    # For other averaging methods, apply the adjustment
     return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k)
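
For context on the micro branch above (an illustrative sketch, not part of the diff): micro averaging aggregates the true-positive and false-negative counts across classes before dividing, which is what the renamed `tp_sum`/`fn_sum` lines compute, while macro-style averaging divides per class first and only then averages. The tensor values below are made up.

```python
import torch

# Hypothetical per-class counts for a 3-class problem (illustrative values only).
tp = torch.tensor([5.0, 2.0, 1.0])
fn = torch.tensor([1.0, 4.0, 3.0])

# Micro: sum the counts over classes first, then divide once.
micro = tp.sum() / (tp.sum() + fn.sum())   # 8 / 16 = 0.50

# Macro-style: divide per class, then average the per-class scores.
macro = (tp / (tp + fn)).mean()            # mean(5/6, 2/6, 1/4) ≈ 0.47

print(micro.item(), macro.item())
```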


@@ -264,6 +275,44 @@ def multiclass_accuracy(
     if validate_args:
         _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
         _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
+
+    if top_k > 1 and average == "micro" and preds.ndim == target.ndim + 1:
+        if preds.ndim == target.ndim:
+            num_classes = num_classes or (target.max().int().item() + 1)
+            preds = torch.nn.functional.one_hot(preds, num_classes).to(preds.dtype)
+            preds = preds.transpose(1, -1)
+
+        if multidim_average == "global":
+            flat_shape = preds.shape[:2] + (-1,)
+            flat_preds = preds.reshape(flat_shape)
+            flat_target = target.reshape(target.shape[0], -1)
+        else:
+            flat_shape = preds.shape[:2] + (-1,)
+            flat_preds = preds.reshape(flat_shape)
+            flat_target = target.reshape(target.shape[0], -1)
+
+        batch_size = flat_target.shape[0]
+        num_samples = flat_target.shape[1]
+
+        if ignore_index is not None:
+            valid_mask = flat_target != ignore_index
+        else:
+            valid_mask = torch.ones_like(flat_target, dtype=torch.bool)
+
+        correct_list = []
+        for i in range(batch_size):
+            for j in range(num_samples):
+                if not valid_mask[i, j]:
+                    continue
+                sample_preds = flat_preds[i, :, j]
+                sample_target = flat_target[i, j]
+                _, top_indices = torch.topk(sample_preds, min(top_k, sample_preds.size(0)), dim=0)
+                correct_list.append(torch.any(top_indices == sample_target).int())
+
+        if correct_list:
+            return torch.stack(correct_list).float().mean()
+        return torch.tensor(0.0, device=preds.device)
+
     preds, target = _multiclass_stat_scores_format(preds, target, top_k)
     tp, fp, tn, fn = _multiclass_stat_scores_update(
         preds, target, num_classes or 1, top_k, average, multidim_average, ignore_index
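
Taken together, the accuracy.py changes are meant to make micro-averaged multiclass accuracy respect `top_k`. The snippet below is an editorial sketch of that intent rather than part of the PR: it compares the metric against a hand-rolled top-2 reference on made-up logits, and with the fix applied the two values are expected to agree.

```python
import torch
from torchmetrics.functional.classification import multiclass_accuracy

# Made-up logits for 4 samples over 3 classes.
preds = torch.tensor([
    [0.6, 0.3, 0.1],
    [0.1, 0.2, 0.7],
    [0.3, 0.45, 0.25],
    [0.2, 0.5, 0.3],
])
target = torch.tensor([0, 2, 2, 1])

# Reference: a sample counts as correct if its target is among the top-2 classes.
top2 = preds.topk(2, dim=1).indices                                   # shape (4, 2)
reference = (top2 == target.unsqueeze(1)).any(dim=1).float().mean()   # 3/4 = 0.75

metric = multiclass_accuracy(preds, target, num_classes=3, top_k=2, average="micro")
print(reference.item(), metric.item())
```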
17 changes: 14 additions & 3 deletions src/torchmetrics/functional/classification/stat_scores.py
@@ -381,15 +381,24 @@ def _multiclass_stat_scores_update(
 ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
     """Compute the statistics.
 
-    - If ``multidim_average`` is equal to samplewise or ``top_k`` is not 1, we transform both preds and
-      target into one hot format.
+    - If ``multidim_average`` is equal to samplewise or ``top_k`` is greater than 1, we transform both preds and
+      target into one hot format to properly handle top-k predictions.
     - Else we calculate statistics by first calculating the confusion matrix and afterwards deriving the
       statistics from that
     - Remove all datapoints that should be ignored. Depending on if ``ignore_index`` is in the set of labels
       or outside we have do use different augmentation strategies when one hot encoding.
 
+    Notes:
+        - For top_k > 1, we always use the one-hot encoding path regardless of the averaging method
+          to ensure top-k logic is properly applied in all cases, including micro averaging.
+
     """
-    if multidim_average == "samplewise" or top_k != 1:
+    # Modified condition to always use one-hot path when top_k > 1, regardless of average method
+    if multidim_average == "samplewise" or top_k > 1 or (preds.ndim == target.ndim + 1 and average == "micro"):
+        # Always use one-hot encoding for:
+        # 1. samplewise averaging
+        # 2. top_k > 1
+        # 3. when inputs have different dimensions (probably logits vs. class indices) and micro averaging
         ignore_in = 0 <= ignore_index <= num_classes - 1 if ignore_index is not None else None
         if ignore_index is not None and not ignore_in:
             preds = preds.clone()
@@ -400,9 +409,11 @@
             preds[idx] = num_classes
 
         if top_k > 1:
+            # For top_k > 1, we need to get the top-k predictions in one-hot format
             preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1)
             preds_oh = _refine_preds_oh(preds, preds_oh, target, top_k)
         else:
+            # Otherwise just one-hot encode the class indices
             preds_oh = torch.nn.functional.one_hot(
                 preds.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
             )
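
As a rough illustration of the one-hot branch this hunk routes more cases through (a plain-torch sketch with made-up values, not the library code itself): the top-k predictions are expanded into a per-class mask, the targets are one-hot encoded, and the per-class tp/fp/tn/fn follow by comparing the two masks. The real implementation additionally relies on `select_topk`, `_refine_preds_oh`, and the `ignore_index` handling shown above, all of which this sketch omits.

```python
import torch

# Made-up probabilities for 3 samples over 3 classes.
preds = torch.tensor([
    [0.6, 0.3, 0.1],
    [0.1, 0.2, 0.7],
    [0.3, 0.45, 0.25],
])
target = torch.tensor([0, 1, 2])
top_k, num_classes = 2, 3

# One-hot mask of the top-k predicted classes per sample, shape (N, C).
topk_idx = preds.topk(top_k, dim=1).indices
preds_oh = torch.zeros_like(preds, dtype=torch.long).scatter_(1, topk_idx, 1)
target_oh = torch.nn.functional.one_hot(target, num_classes)

# Per-class statistics derived from the two one-hot masks.
tp = ((target_oh == 1) & (preds_oh == 1)).sum(0)
fp = ((target_oh == 0) & (preds_oh == 1)).sum(0)
tn = ((target_oh == 0) & (preds_oh == 0)).sum(0)
fn = ((target_oh == 1) & (preds_oh == 0)).sum(0)
print(tp, fp, tn, fn)  # micro accuracy would then be tp.sum() / (tp.sum() + fn.sum())
```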