Added available_device in test_classification_report (#3335) #3342

Merged · 5 commits · Mar 23, 2025
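For context: available_device is a pytest fixture provided by ignite's test suite conftest that selects the device a test should run on. Its implementation is not shown in this diff; below is a minimal sketch of what such a fixture typically looks like (the name of the returned device logic and backend coverage here are assumptions, not the real conftest code).

import pytest
import torch


@pytest.fixture
def available_device():
    # Hypothetical sketch: prefer an accelerator when one is present,
    # fall back to CPU otherwise. The real fixture may cover more backends.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")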
tests/ignite/metrics/test_classification_report.py (259 changes: 146 additions, 113 deletions)
@@ -10,72 +10,139 @@
from ignite.metrics.classification_report import ClassificationReport


def _test_multiclass(metric_device, n_classes, output_dict, labels=None, distributed=False):
    if distributed:
        device = idist.device()
    else:
        device = metric_device

    classification_report = ClassificationReport(device=metric_device, output_dict=output_dict, labels=labels)
    n_iters = 80
    batch_size = 16

    y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
    y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)

    def update(engine, i):
        return (
            y_preds[i * batch_size : (i + 1) * batch_size, :],
            y_true[i * batch_size : (i + 1) * batch_size],
        )

    engine = Engine(update)

    classification_report.attach(engine, "cr")

    data = list(range(n_iters))
    engine.run(data=data)

    if distributed:
        y_preds = idist.all_gather(y_preds)
        y_true = idist.all_gather(y_true)

    assert "cr" in engine.state.metrics
    res = engine.state.metrics["cr"]
    res2 = classification_report.compute()
    assert res == res2

    assert isinstance(res, dict if output_dict else str)
    if not output_dict:
        res = json.loads(res)

    from sklearn.metrics import classification_report as sklearn_classification_report

    sklearn_result = sklearn_classification_report(
        y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy(), output_dict=True, zero_division=1
    )

    for i in range(n_classes):
        label_i = labels[i] if labels else str(i)
        assert sklearn_result[str(i)]["precision"] == pytest.approx(res[label_i]["precision"])
        assert sklearn_result[str(i)]["f1-score"] == pytest.approx(res[label_i]["f1-score"])
        assert sklearn_result[str(i)]["recall"] == pytest.approx(res[label_i]["recall"])
    assert sklearn_result["macro avg"]["precision"] == pytest.approx(res["macro avg"]["precision"])
    assert sklearn_result["macro avg"]["recall"] == pytest.approx(res["macro avg"]["recall"])
    assert sklearn_result["macro avg"]["f1-score"] == pytest.approx(res["macro avg"]["f1-score"])

    metric_state = classification_report.state_dict()
    classification_report.reset()
    classification_report.load_state_dict(metric_state)

    res2 = classification_report.compute()
    if not output_dict:
        res2 = json.loads(res2)

    for i in range(n_classes):
        label_i = labels[i] if labels else str(i)
        assert res2[label_i]["precision"] == res[label_i]["precision"]
        assert res2[label_i]["f1-score"] == res[label_i]["f1-score"]
        assert res2[label_i]["recall"] == res[label_i]["recall"]
    assert res2["macro avg"]["precision"] == res["macro avg"]["precision"]
    assert res2["macro avg"]["recall"] == res["macro avg"]["recall"]
    assert res2["macro avg"]["f1-score"] == res["macro avg"]["f1-score"]
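The distributed branch above hinges on idist.all_gather: each process updates the metric on its own shard, and gathering reassembles the full y_preds/y_true so the sklearn reference is computed over exactly the data the metric accumulated across all processes. A toy illustration of those semantics, assuming a running two-process group (not part of the diff):

import torch

import ignite.distributed as idist


def gather_demo():
    # Assumed 2-process group: rank 0 holds [0., 1.], rank 1 holds [2., 3.]
    local = torch.arange(2, dtype=torch.float32) + 2 * idist.get_rank()
    # all_gather concatenates the shards along dim 0 on every rank,
    # so each process ends up with [0., 1., 2., 3.]
    return idist.all_gather(local)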


def _test_multilabel(metric_device, n_epochs, output_dict, labels=None, distributed=False):
    if distributed:
        device = idist.device()
    else:
        device = metric_device

    classification_report = ClassificationReport(device=metric_device, output_dict=output_dict, is_multilabel=True)

    n_iters = 10
    batch_size = 16
    n_classes = 7

    y_true = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)
    y_preds = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)

    def update(engine, i):
        return (
            y_preds[i * batch_size : (i + 1) * batch_size, ...],
            y_true[i * batch_size : (i + 1) * batch_size, ...],
        )

    engine = Engine(update)

    classification_report.attach(engine, "cr")

    data = list(range(n_iters))
    engine.run(data=data, max_epochs=n_epochs)

    if distributed:
        y_preds = idist.all_gather(y_preds)
        y_true = idist.all_gather(y_true)

    assert "cr" in engine.state.metrics
    res = engine.state.metrics["cr"]
    res2 = classification_report.compute()
    assert res == res2

    assert isinstance(res, dict if output_dict else str)
    if not output_dict:
        res = json.loads(res)

    np_y_preds = to_numpy_multilabel(y_preds)
    np_y_true = to_numpy_multilabel(y_true)

    from sklearn.metrics import classification_report as sklearn_classification_report

    sklearn_result = sklearn_classification_report(np_y_true, np_y_preds, output_dict=True, zero_division=1)

    for i in range(n_classes):
        label_i = labels[i] if labels else str(i)
        assert sklearn_result[str(i)]["precision"] == pytest.approx(res[label_i]["precision"])
        assert sklearn_result[str(i)]["f1-score"] == pytest.approx(res[label_i]["f1-score"])
        assert sklearn_result[str(i)]["recall"] == pytest.approx(res[label_i]["recall"])
    assert sklearn_result["macro avg"]["precision"] == pytest.approx(res["macro avg"]["precision"])
    assert sklearn_result["macro avg"]["recall"] == pytest.approx(res["macro avg"]["recall"])
    assert sklearn_result["macro avg"]["f1-score"] == pytest.approx(res["macro avg"]["f1-score"])

    metric_state = classification_report.state_dict()
    classification_report.reset()
    classification_report.load_state_dict(metric_state)
    res2 = classification_report.compute()
    if not output_dict:
        res2 = json.loads(res2)

    for i in range(n_classes):
        label_i = labels[i] if labels else str(i)
        assert res2[label_i]["precision"] == res[label_i]["precision"]
        assert res2[label_i]["f1-score"] == res[label_i]["f1-score"]
        assert res2[label_i]["recall"] == res[label_i]["recall"]
    assert res2["macro avg"]["precision"] == res["macro avg"]["precision"]
    assert res2["macro avg"]["recall"] == res["macro avg"]["recall"]
    assert res2["macro avg"]["f1-score"] == res["macro avg"]["f1-score"]
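The multilabel path cannot feed (N, C, 6, 8) tensors to sklearn directly; to_numpy_multilabel (imported from ignite's test utilities) flattens them to the 2-D (n_samples, n_classes) layout that classification_report expects. A hypothetical sketch of that reshaping, for readers without the test utilities at hand (the real helper may differ in detail):

import torch


def to_numpy_multilabel_sketch(y):
    # (N, C, ...) -> (N * prod(...), C): move the class axis out, then
    # flatten every remaining axis into the sample dimension. Applying the
    # same permutation to y_preds and y_true keeps their rows aligned.
    n_classes = y.size(1)
    y = y.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
    return y.cpu().numpy()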


def _test_integration_multiclass(device, output_dict):
    rank = idist.get_rank()
    labels = ["label0", "label1", "label2", "label3"]

    for i in range(5):
        torch.manual_seed(12 + rank + i)

@@ -84,79 +151,45 @@ def update(engine, i):

        if device.type != "xla":
            metric_devices.append(idist.device())
        for metric_device in metric_devices:
            for n_classes in range(2, len(labels) + 1):
                for output_dict in [False, True]:
                    _test_multiclass(metric_device, n_classes, output_dict, distributed=True)
                    _test_multiclass(metric_device, n_classes, output_dict, labels=labels[:n_classes], distributed=True)


def _test_integration_multilabel(device, output_dict):
    rank = idist.get_rank()

    for i in range(3):
        torch.manual_seed(12 + rank + i)
        # check multiple random inputs as random exact occurrences are rare
        metric_devices = ["cpu"]
        if device.type != "xla":
            metric_devices.append(idist.device())
        for metric_device in metric_devices:
            for n_epochs in [1, 2]:
                for output_dict in [False, True]:
                    _test_multilabel(metric_device, n_epochs, output_dict, distributed=True)
                    _test_multilabel(
                        metric_device, n_epochs, output_dict, ["0", "1", "2", "3", "4", "5", "6"], distributed=True
                    )
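These two wrappers are what the distributed test entry points call: inside a distributed context, every process runs them with distributed=True so the shared helpers gather the shards before checking. A hypothetical launcher illustrating how that path could be driven manually (the real suite relies on ignite's distributed pytest fixtures instead):

import ignite.distributed as idist


def _run(local_rank):
    device = idist.device()
    _test_integration_multiclass(device, output_dict=True)
    _test_integration_multilabel(device, output_dict=True)


if __name__ == "__main__":
    # Spawn 2 gloo processes; idist.spawn passes the process index as
    # the first argument to _run.
    idist.spawn("gloo", _run, args=(), nproc_per_node=2)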


Review thread on test_compute_multiclass:

vfdev-5 (Collaborator): Let's refactor the _test_integration_multiclass test such that we can call it in both the usual and distributed configs. It is not good to copy-paste the code. Same for the second test method.

HyeSungP (Contributor, author) · Mar 11, 2025: @vfdev-5 Thank you for your advice! So, should I refactor _test_integration_multiclass and _test_integration_multilabel to extract the common logic and make it reusable for both the usual and distributed configurations?

vfdev-5 (Collaborator): Yes, we should extract the common code that can be used with 1) the new test you have written with available_device and 2) the distributed test.

HyeSungP (Contributor, author): @vfdev-5 Hello! I have made a new commit. Could you check it for me?


@pytest.mark.parametrize("n_times", range(5))
def test_compute_multiclass(n_times, available_device):
    labels = ["label0", "label1", "label2", "label3"]
    for n_classes in range(2, len(labels) + 1):
        for output_dict in [False, True]:
            _test_multiclass(available_device, n_classes, output_dict)
            _test_multiclass(available_device, n_classes, output_dict, labels[:n_classes])


@pytest.mark.parametrize("n_times", range(5))
def test_compute_multilabel(n_times, available_device):
    for n_epochs in [1, 2]:
        for output_dict in [False, True]:
            _test_multilabel(available_device, n_epochs, output_dict)
            _test_multilabel(available_device, n_epochs, output_dict, ["0", "1", "2", "3", "4", "5", "6"])


@pytest.mark.distributed
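Stepping back, the review converged on the pattern this diff implements: one helper owns the whole check, and thin entry points only pick the configuration. A distilled sketch of that shape (bodies elided; the names here are illustrative, not the file's):

import ignite.distributed as idist


def _check(metric_device, distributed=False):
    # Shared logic: build data on `device`, attach the metric, run the
    # engine, compare against the sklearn reference.
    device = idist.device() if distributed else metric_device
    ...


def test_compute(available_device):
    # Usual config: driven by the available_device pytest fixture.
    _check(available_device)


def _test_integration(device):
    # Distributed config: same logic, shards gathered inside _check.
    _check(device, distributed=True)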