Skip to content

Commit a5d3464

Browse files
authored
simplify imports of metric functions (#3292)
1 parent ad02551 commit a5d3464

File tree

2 files changed

+11
-20
lines changed

2 files changed

+11
-20
lines changed

ignite/metrics/cohen_kappa.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,23 +71,17 @@ def __init__(
7171
# initalize weights
7272
self.weights = weights
7373

74-
self.cohen_kappa_compute = self.get_cohen_kappa_fn()
75-
7674
super(CohenKappa, self).__init__(
77-
self.cohen_kappa_compute,
75+
self._cohen_kappa_score,
7876
output_transform=output_transform,
7977
check_compute_fn=check_compute_fn,
8078
device=device,
8179
skip_unrolling=skip_unrolling,
8280
)
8381

84-
def get_cohen_kappa_fn(self) -> Callable[[torch.Tensor, torch.Tensor], float]:
85-
"""Return a function computing Cohen Kappa from scikit-learn."""
82+
def _cohen_kappa_score(self, y_targets: torch.Tensor, y_preds: torch.Tensor) -> float:
8683
from sklearn.metrics import cohen_kappa_score
8784

88-
def wrapper(y_targets: torch.Tensor, y_preds: torch.Tensor) -> float:
89-
y_true = y_targets.cpu().numpy()
90-
y_pred = y_preds.cpu().numpy()
91-
return cohen_kappa_score(y_true, y_pred, weights=self.weights)
92-
93-
return wrapper
85+
y_true = y_targets.cpu().numpy()
86+
y_pred = y_preds.cpu().numpy()
87+
return cohen_kappa_score(y_true, y_pred, weights=self.weights)

ignite/metrics/regression/spearman_correlation.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,13 @@
99
from ignite.metrics.regression._base import _check_output_shapes, _check_output_types
1010

1111

12-
def _get_spearman_r() -> Callable[[Tensor, Tensor], float]:
12+
def _spearman_r(predictions: Tensor, targets: Tensor) -> float:
1313
from scipy.stats import spearmanr
1414

15-
def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float:
16-
np_preds = predictions.flatten().numpy()
17-
np_targets = targets.flatten().numpy()
18-
r = spearmanr(np_preds, np_targets).statistic
19-
return r
20-
21-
return _compute_spearman_r
15+
np_preds = predictions.flatten().numpy()
16+
np_targets = targets.flatten().numpy()
17+
r = spearmanr(np_preds, np_targets).statistic
18+
return r
2219

2320

2421
class SpearmanRankCorrelation(EpochMetric):
@@ -92,7 +89,7 @@ def __init__(
9289
except ImportError:
9390
raise ModuleNotFoundError("This module requires scipy to be installed.")
9491

95-
super().__init__(_get_spearman_r(), output_transform, check_compute_fn, device, skip_unrolling)
92+
super().__init__(_spearman_r, output_transform, check_compute_fn, device, skip_unrolling)
9693

9794
def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
9895
y_pred, y = output[0].detach(), output[1].detach()

0 commit comments

Comments
 (0)