Use f32 for metrics on mps #3334

Merged
merged 6 commits on Feb 28, 2025
8 changes: 4 additions & 4 deletions .github/workflows/mps-tests.yml
@@ -26,7 +26,7 @@ concurrency:
group: mps-tests-${{ github.ref_name }}-${{ !(github.ref_protected) || github.sha }}
cancel-in-progress: true

# Cherry-picked from
# - https://github.com/pytorch/vision/blob/main/.github/workflows/tests.yml
# - https://github.com/pytorch/test-infra/blob/main/.github/workflows/macos_job.yml

@@ -40,7 +40,7 @@ jobs:
fail-fast: false
runs-on: ["macos-m1-stable"]
timeout-minutes: 60

steps:
- name: Clean workspace
run: |
@@ -76,15 +76,15 @@ jobs:
run: |
conda shell.bash hook
conda activate $CONDA_ENV
- pip install torch torchvision
+ pip install -U torch torchvision

- name: Install PyTorch (nightly)
if: ${{ matrix.pytorch-channel == 'pytorch-nightly' }}
shell: bash -l {0}
run: |
conda shell.bash hook
conda activate $CONDA_ENV
- pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+ pip install --pre -U torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu

- name: Install dependencies
shell: bash -l {0}
2 changes: 1 addition & 1 deletion ignite/metrics/accumulation.py
@@ -63,7 +63,7 @@ def __init__(

@reinit__is_reduced
def reset(self) -> None:
- self.accumulator = torch.tensor(0.0, dtype=torch.float64, device=self._device)
+ self.accumulator = torch.tensor(0.0, dtype=self._double_dtype, device=self._device)
self.num_examples = 0

def _check_output_type(self, output: Union[float, torch.Tensor]) -> None:
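The switch from a hard-coded torch.float64 to self._double_dtype (defined in ignite/metrics/metric.py below) works around the MPS backend's lack of float64 support. A minimal sketch of the failure being avoided, assuming a machine with a functional MPS backend:

import torch

# On MPS, requesting a float64 tensor raises a TypeError along the lines of
# "Cannot convert a MPS Tensor to float64 dtype as the MPS framework
# doesn't support float64. Please use float32 instead."
if torch.backends.mps.is_available():
    torch.tensor(0.0, dtype=torch.float64, device="mps")  # raises TypeError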
8 changes: 4 additions & 4 deletions ignite/metrics/gan/fid.py
@@ -214,16 +214,16 @@ def _get_covariance(self, sigma: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
@reinit__is_reduced
def reset(self) -> None:
self._train_sigma = torch.zeros(
- (self._num_features, self._num_features), dtype=torch.float64, device=self._device
+ (self._num_features, self._num_features), dtype=self._double_dtype, device=self._device
)

- self._train_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
+ self._train_total = torch.zeros(self._num_features, dtype=self._double_dtype, device=self._device)

self._test_sigma = torch.zeros(
- (self._num_features, self._num_features), dtype=torch.float64, device=self._device
+ (self._num_features, self._num_features), dtype=self._double_dtype, device=self._device
)

- self._test_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
+ self._test_total = torch.zeros(self._num_features, dtype=self._double_dtype, device=self._device)
self._num_examples: int = 0

super(FID, self).reset() # type: ignore
8 changes: 4 additions & 4 deletions ignite/metrics/gan/inception_score.py
@@ -103,20 +103,20 @@ def __init__(
def reset(self) -> None:
self._num_examples = 0

- self._prob_total = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
- self._total_kl_d = torch.zeros(self._num_features, dtype=torch.float64, device=self._device)
+ self._prob_total = torch.zeros(self._num_features, dtype=self._double_dtype, device=self._device)
+ self._total_kl_d = torch.zeros(self._num_features, dtype=self._double_dtype, device=self._device)

super(InceptionScore, self).reset() # type: ignore

@reinit__is_reduced
def update(self, output: torch.Tensor) -> None:
probabilities = self._extract_features(output)

- prob_sum = torch.sum(probabilities, 0, dtype=torch.float64)
+ prob_sum = torch.sum(probabilities, 0, dtype=self._double_dtype)
log_prob = torch.log(probabilities + self._eps)
if log_prob.dtype != probabilities.dtype:
log_prob = log_prob.to(probabilities)
- kl_sum = torch.sum(probabilities * log_prob, 0, dtype=torch.float64)
+ kl_sum = torch.sum(probabilities * log_prob, 0, dtype=self._double_dtype)

self._num_examples += probabilities.shape[0]
self._prob_total += prob_sum
2 changes: 1 addition & 1 deletion ignite/metrics/gan/utils.py
@@ -95,7 +95,7 @@ def _extract_features(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = inputs.to(self._device)

with torch.no_grad():
- outputs = self._feature_extractor(inputs).to(self._device, dtype=torch.float64)
+ outputs = self._feature_extractor(inputs).to(self._device, dtype=self._double_dtype)
self._check_feature_shapes(outputs)

return outputs
6 changes: 6 additions & 0 deletions ignite/metrics/metric.py
@@ -369,6 +369,12 @@ def __init__(

self._device = torch.device(device)
self._skip_unrolling = skip_unrolling

+ # The MPS framework doesn't support float64, so fall back to float32
+ self._double_dtype = torch.float64
+ if self._device.type == "mps":
+ self._double_dtype = torch.float32

self.reset()

@abstractmethod
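A standalone sketch of the dtype-selection pattern this hunk adds to Metric.__init__; resolve_double_dtype is a hypothetical helper for illustration, not part of the PR:

import torch

def resolve_double_dtype(device: torch.device) -> torch.dtype:
    # Accumulate in float32 on MPS (which has no float64 kernels);
    # keep full double precision everywhere else.
    return torch.float32 if device.type == "mps" else torch.float64

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
accumulator = torch.tensor(0.0, dtype=resolve_double_dtype(device), device=device)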
2 changes: 1 addition & 1 deletion ignite/metrics/multilabel_confusion_matrix.py
@@ -136,7 +136,7 @@ def compute(self) -> torch.Tensor:
raise NotComputableError("Confusion matrix must have at least one example before it can be computed.")

if self.normalized:
- conf = self.confusion_matrix.to(dtype=torch.float64)
+ conf = self.confusion_matrix.to(dtype=self._double_dtype)
sums = conf.sum(dim=(1, 2))
return conf / sums[:, None, None]

4 changes: 2 additions & 2 deletions ignite/metrics/precision.py
@@ -81,8 +81,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tensor]:
y = torch.transpose(y, 1, -1).reshape(-1, num_labels)

# Convert from int cuda/cpu to double on self._device
- y_pred = y_pred.to(dtype=torch.float64, device=self._device)
- y = y.to(dtype=torch.float64, device=self._device)
+ y_pred = y_pred.to(dtype=self._double_dtype, device=self._device)
+ y = y.to(dtype=self._double_dtype, device=self._device)
correct = y * y_pred

return y_pred, y, correct
4 changes: 2 additions & 2 deletions ignite/metrics/psnr.py
@@ -113,7 +113,7 @@ def _check_shape_dtype(self, output: Sequence[torch.Tensor]) -> None:

@reinit__is_reduced
def reset(self) -> None:
- self._sum_of_batchwise_psnr = torch.tensor(0.0, dtype=torch.float64, device=self._device)
+ self._sum_of_batchwise_psnr = torch.tensor(0.0, dtype=self._double_dtype, device=self._device)
self._num_examples = 0

@reinit__is_reduced
@@ -122,7 +122,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()

dim = tuple(range(1, y.ndim))
- mse_error = torch.pow(y_pred.double() - y.view_as(y_pred).double(), 2).mean(dim=dim)
+ mse_error = torch.pow(y_pred.to(self._double_dtype) - y.view_as(y_pred).to(self._double_dtype), 2).mean(dim=dim)
self._sum_of_batchwise_psnr += torch.sum(10.0 * torch.log10(self.data_range**2 / (mse_error + 1e-10))).to(
device=self._device
)
25 changes: 18 additions & 7 deletions ignite/metrics/ssim.py
@@ -116,11 +116,11 @@ def __init__(

@reinit__is_reduced
def reset(self) -> None:
- self._sum_of_ssim = torch.tensor(0.0, dtype=torch.float64, device=self._device)
+ self._sum_of_ssim = torch.tensor(0.0, dtype=self._double_dtype, device=self._device)
self._num_examples = 0

def _uniform(self, kernel_size: int) -> torch.Tensor:
- kernel = torch.zeros(kernel_size)
+ kernel = torch.zeros(kernel_size, device=self._device)

start_uniform_index = max(kernel_size // 2 - 2, 0)
end_uniform_index = min(kernel_size // 2 + 3, kernel_size)
@@ -146,10 +146,7 @@ def _gaussian_or_uniform_kernel(self, kernel_size: Sequence[int], sigma: Sequence[float]) -> torch.Tensor:

return torch.matmul(kernel_x.t(), kernel_y) # (kernel_size, 1) * (1, kernel_size)

- @reinit__is_reduced
- def update(self, output: Sequence[torch.Tensor]) -> None:
- y_pred, y = output[0].detach(), output[1].detach()

+ def _check_type_and_shape(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
if y_pred.dtype != y.dtype:
raise TypeError(
f"Expected y_pred and y to have the same data type. Got y_pred: {y_pred.dtype} and y: {y.dtype}."
@@ -165,6 +162,12 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
f"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
)

+ @reinit__is_reduced
+ def update(self, output: Sequence[torch.Tensor]) -> None:
+ y_pred, y = output[0].detach(), output[1].detach()

+ self._check_type_and_shape(y_pred, y)

# converts potential integer tensor to fp
if not y.is_floating_point():
y = y.float()
@@ -213,7 +216,15 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
b2 = sigma_pred_sq + sigma_target_sq + self.c2

ssim_idx = (a1 * a2) / (b1 * b2)
- self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum().to(device=self._device)
+
+ # ssim_idx may be an MPS tensor while self._device is not MPS,
+ # in which case self._double_dtype is float64. Since MPS does not
+ # support float64, fall back to float32 for the reduction.
+ double_dtype = self._double_dtype
+ if ssim_idx.device.type == "mps" and self._double_dtype == torch.float64:
+ double_dtype = torch.float32
+
+ self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=double_dtype).sum().to(device=self._device)

self._num_examples += y.shape[0]

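The extra fallback above is needed because update() may receive batches on a different device than the metric's own, e.g. a CPU-accumulating metric fed MPS tensors. A sketch of that case with hypothetical values (assumes an MPS machine):

import torch

metric_device = torch.device("cpu")  # self._device
double_dtype = torch.float64         # self._double_dtype for a CPU metric

ssim_idx = torch.rand(4, 3, 8, 8, device="mps")  # batch computed on MPS
if ssim_idx.device.type == "mps" and double_dtype == torch.float64:
    double_dtype = torch.float32  # MPS cannot reduce into float64

# Reduce on MPS in float32, then move the scalar sum to the metric's device.
batch_sum = torch.mean(ssim_idx, (1, 2, 3), dtype=double_dtype).sum().to(metric_device)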
24 changes: 15 additions & 9 deletions tests/ignite/conftest.py
@@ -69,19 +69,25 @@ def term_handler():
yield # Just pass through if SIGTERM isn't supported or we are not in the main thread


- @pytest.fixture(
- params=[
- "cpu",
- pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no CUDA support")),
- pytest.param(
- "mps", marks=pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS support")
- ),
- ]
- )
+ available_devices_list = [
+ "cpu",
+ pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no CUDA support")),
+ pytest.param(
+ "mps", marks=pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS support")
+ ),
+ ]


+ @pytest.fixture(params=available_devices_list)
def available_device(request):
return request.param


+ @pytest.fixture(params=available_devices_list)
+ def available_device2(request):
+ return request.param


@pytest.fixture()
def dirname():
path = Path(tempfile.mkdtemp())
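For reference, a hypothetical test showing how the parametrized fixture is consumed; pytest runs it once per device, skipping the cuda/mps params on machines without support, while available_device2 serves tests that need two independently parametrized devices:

import torch
from ignite.metrics import Accuracy

def test_accuracy_on_available_device(available_device):
    metric = Accuracy(device=available_device)
    y_pred = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
    y = torch.tensor([0, 1])
    metric.update((y_pred, y))
    assert metric.compute() == 1.0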
35 changes: 22 additions & 13 deletions tests/ignite/metrics/gan/test_fid.py
@@ -53,18 +53,23 @@ def test_fid_function():
)


- def test_compute_fid_from_features():
+ def test_compute_fid_from_features(available_device):
train_samples, test_samples = torch.rand(10, 10), torch.rand(10, 10)

- fid_scorer = FID(num_features=10, feature_extractor=torch.nn.Identity())
+ fid_scorer = FID(
+ num_features=10,
+ feature_extractor=torch.nn.Identity(),
+ device=available_device,
+ )
fid_scorer.update([train_samples[:5], test_samples[:5]])
fid_scorer.update([train_samples[5:], test_samples[5:]])

mu1, sigma1 = train_samples.mean(axis=0), cov(train_samples, rowvar=False)
mu2, sigma2 = test_samples.mean(axis=0), cov(test_samples, rowvar=False)

+ tol = 1e-4 if available_device == "mps" else 1e-5
assert (
- pytest.approx(pytorch_fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2), rel=1e-5)
+ pytest.approx(pytorch_fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2), rel=tol)
== fid_scorer.compute()
)

@@ -128,28 +133,32 @@ def test_wrong_inputs():
FID(feature_extractor=torch.nn.Identity())


- def test_statistics():
+ def test_statistics(available_device):
train_samples, test_samples = torch.rand(10, 10), torch.rand(10, 10)
- fid_scorer = FID(num_features=10, feature_extractor=torch.nn.Identity())
+ fid_scorer = FID(
+ num_features=10,
+ feature_extractor=torch.nn.Identity(),
+ device=available_device,
+ )
fid_scorer.update([train_samples[:5], test_samples[:5]])
fid_scorer.update([train_samples[5:], test_samples[5:]])

- mu1, sigma1 = train_samples.mean(axis=0), torch.tensor(cov(train_samples, rowvar=False))
- mu2, sigma2 = test_samples.mean(axis=0), torch.tensor(cov(test_samples, rowvar=False))
+ mu1 = train_samples.mean(axis=0, dtype=torch.float64)
+ sigma1 = torch.tensor(cov(train_samples, rowvar=False), dtype=torch.float64)
+ mu2 = test_samples.mean(axis=0, dtype=torch.float64)
+ sigma2 = torch.tensor(cov(test_samples, rowvar=False), dtype=torch.float64)

fid_mu1 = fid_scorer._train_total / fid_scorer._num_examples
fid_sigma1 = fid_scorer._get_covariance(fid_scorer._train_sigma, fid_scorer._train_total)

fid_mu2 = fid_scorer._test_total / fid_scorer._num_examples
fid_sigma2 = fid_scorer._get_covariance(fid_scorer._test_sigma, fid_scorer._test_total)

- assert torch.isclose(mu1.double(), fid_mu1).all()
- for cov1, cov2 in zip(sigma1, fid_sigma1):
- assert torch.isclose(cov1.double(), cov2, rtol=1e-04, atol=1e-04).all()
+ assert torch.allclose(mu1, fid_mu1.cpu().to(dtype=mu1.dtype))
+ assert torch.allclose(sigma1, fid_sigma1.cpu().to(dtype=sigma1.dtype), rtol=1e-04, atol=1e-04)

- assert torch.isclose(mu2.double(), fid_mu2).all()
- for cov1, cov2 in zip(sigma2, fid_sigma2):
- assert torch.isclose(cov1.double(), cov2, rtol=1e-04, atol=1e-04).all()
+ assert torch.allclose(mu2, fid_mu2.cpu().to(dtype=mu2.dtype))
+ assert torch.allclose(sigma2, fid_sigma2.cpu().to(dtype=mu2.dtype), rtol=1e-04, atol=1e-04)


def _test_distrib_integration(device):
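A rough illustration of why the MPS path needs the looser rel=1e-4 above: the accumulators are float32 there (about 7 significant decimal digits), so results drift from the float64 reference well before 1e-5 once sums and matrix operations stack up. Illustrative only, not from the PR:

import torch

x = torch.rand(10, 10)
s32 = x.sum(dtype=torch.float32)
s64 = x.sum(dtype=torch.float64)
rel_gap = abs(s32.item() - s64.item()) / abs(s64.item())
# rel_gap is typically on the order of 1e-7 here; longer accumulations and
# the covariance/sqrtm steps in FID can amplify it toward 1e-5 and beyond.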
10 changes: 7 additions & 3 deletions tests/ignite/metrics/gan/test_inception_score.py
@@ -20,14 +20,18 @@ def calculate_inception_score(p_yx):
return is_score


- def test_inception_score():
+ def test_inception_score(available_device):
p_yx = torch.rand(20, 10)
- m = InceptionScore(num_features=10, feature_extractor=torch.nn.Identity())
+ m = InceptionScore(
+ num_features=10,
+ feature_extractor=torch.nn.Identity(),
+ device=available_device,
+ )
m.update(p_yx)
assert pytest.approx(calculate_inception_score(p_yx)) == m.compute()

p_yx = torch.rand(20, 3, 299, 299)
- m = InceptionScore()
+ m = InceptionScore(device=available_device)
m.update(p_yx)
assert isinstance(m.compute(), float)

4 changes: 2 additions & 2 deletions tests/ignite/metrics/test_average_precision.py
@@ -63,7 +63,7 @@ def test_check_shape():
ap._check_shape((torch.rand(4, 3), torch.rand(4, 3, 1)))


- @pytest.fixture(params=[item for item in range(8)])
+ @pytest.fixture(params=range(8))
def test_data_binary_and_multilabel(request):
return [
# Binary input data of shape (N,) or (N, 1)
@@ -102,7 +102,7 @@ def test_binary_and_multilabel_inputs(n_times, test_data_binary_and_multilabel):
assert average_precision_score(np_y, np_y_pred) == pytest.approx(res)


- @pytest.fixture(params=[item for item in range(4)])
+ @pytest.fixture(params=range(4))
def test_data_integration_binary_and_multilabel(request):
return [
# Binary input data of shape (N,) or (N, 1)