
Metrics computation on pytorch MPS backend fails, needs float32 instead of float64 #3326


Closed
blackplane opened this issue Feb 3, 2025 · 5 comments · Fixed by #3334

Comments

@blackplane

🐛 Bug description

Running metrics via evaluator.run(dataloader) on macOS fails because the PyTorch MPS backend doesn't support the float64 dtype that y_pred is cast to (source code link).

Suggestion: cast to float32 instead. It is true that int64 cannot be cast to float32 without a potential loss of precision, and float64 is technically the safer choice. Pragmatically, though, I'd think this loss won't matter in practice: y_pred represents class labels, and float32 represents integers exactly up to 2^24 (about 16.7 million), which is already more classes than anyone will need.
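
To make the precision trade-off concrete: float32 has a 24-bit significand, so it represents consecutive integers exactly only up to 2^24 (16,777,216). The sketch below checks this with the standard-library `struct` module, without needing PyTorch. The helper names `roundtrip_float32` and `is_exact_in_float32` are illustrative and are not part of ignite or PyTorch.

```python
import struct


def roundtrip_float32(value: int) -> float:
    """Store an integer as an IEEE 754 float32 and read it back."""
    packed = struct.pack("<f", float(value))
    return struct.unpack("<f", packed)[0]


def is_exact_in_float32(value: int) -> bool:
    """Return True if the integer survives a float32 round trip unchanged."""
    return roundtrip_float32(value) == value


exact_limit = 2 ** 24  # largest power of two below which every integer is exact

print(is_exact_in_float32(1000))             # a typical class label
print(is_exact_in_float32(exact_limit))      # the boundary itself is exact
print(is_exact_in_float32(exact_limit + 1))  # the first integer that is not
```

Class indices in classification tasks are far below this boundary, which is why the cast is unlikely to change metric results in practice.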

File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/engine/engine.py", line 889, in run
return self._internal_run()
^^^^^^^^^^^^^^^^^^^^
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/engine/engine.py", line 932, in _internal_run
return next(self._internal_run_generator)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/engine/engine.py", line 990, in _internal_run_as_gen
self._handle_exception(e)
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/engine/engine.py", line 644, in _handle_exception
raise e
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/engine/engine.py", line 956, in _internal_run_as_gen
epoch_time_taken += yield from self._run_once_on_dataset_as_gen()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/engine/engine.py", line 1096, in _run_once_on_dataset_as_gen
self._handle_exception(e)
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/engine/engine.py", line 644, in _handle_exception
raise e
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/engine/engine.py", line 1078, in _run_once_on_dataset_as_gen
self._fire_event(Events.ITERATION_COMPLETED)
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/engine/engine.py", line 431, in _fire_event
func(*first, *(event_args + others), **kwargs)
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/metrics/metric.py", line 469, in iteration_completed
self.update(output)
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/metrics/metric.py", line 864, in wrapper
func(self, *args, **kwargs)
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/metrics/recall.py", line 227, in update
_, y, correct = self._prepare_output(output)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/user/Library/Caches/pypoetry/virtualenvs/columbo-HtoL9iYn-py3.12/lib/python3.12/site-packages/ignite/metrics/precision.py", line 84, in _prepare_output
y_pred = y_pred.to(dtype=torch.float64, device=self._device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Environment

  • PyTorch Version (e.g., 1.4): 2.6.0
  • Ignite Version (e.g., 0.3.0): 0.5.1
  • OS (e.g., Linux): macOS
  • How you installed Ignite (conda, pip, source): pip
  • Python version: 3.12.8
  • Any other relevant information:
@blackplane changed the title from "Metrics computation on MPS device fails, needs float32 instead of float64" to "Metrics computation on pytorch MPS backend fails, needs float32 instead of float64" on Feb 3, 2025
@vfdev-5
Collaborator

vfdev-5 commented Feb 3, 2025

@blackplane thanks for the report!

One question: do you explicitly specify the MPS device in the metric definition?

precision = Precision(device="mps")

We could add a code path that uses float32 on MPS devices wherever we currently set the float64 dtype for computations, and we would probably need to increase the tolerance in the tests.
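
A minimal sketch of that idea, under stated assumptions: pick the float dtype from the device type, falling back to float32 only on MPS. The function name `float_dtype_name_for` is hypothetical and this is not ignite's actual implementation. To stay runnable without PyTorch installed, it returns the dtype name as a string and accepts either a device string or any object with a `.type` attribute (as `torch.device` has).

```python
def float_dtype_name_for(device) -> str:
    """Return the name of the widest float dtype assumed usable on `device`.

    `device` may be a string such as "mps" or "cuda:0", or any object
    exposing a `.type` attribute (for example a torch.device).
    """
    device_type = getattr(device, "type", None)
    if device_type is None:
        # Plain strings like "cuda:0" carry the type before the colon.
        device_type = str(device).split(":", 1)[0]
    # Assumption: MPS is the only backend here that lacks float64 support.
    return "float32" if device_type == "mps" else "float64"


for name in ("cpu", "cuda:0", "mps"):
    print(name, "->", float_dtype_name_for(name))
```

With PyTorch available, the name could be resolved with `getattr(torch, float_dtype_name_for(device))` and passed as the `dtype` argument of `y_pred.to(...)`.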

@dmalinverni

Hi,
Bumping this up: I'm facing the same problem (also using the Precision() metric on the MPS backend).

Using
precision = Precision(device="mps")
did not resolve the problem for me.

Thanks for the support and the great library!

@vfdev-5
Collaborator

vfdev-5 commented Feb 26, 2025

@dmalinverni can you please provide the following info:

  • pytorch version
  • traceback of the failure
  • if you set device to "cpu", does it pass?

Thanks!
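
The version details requested above can be gathered with a small script. This is a sketch that relies only on the standard library; the function names `installed_version` and `environment_report` are illustrative. It assumes the packages were installed under their PyPI distribution names, `torch` and `pytorch-ignite`.

```python
import platform
from importlib import metadata


def installed_version(package: str) -> str:
    """Return the installed version of a distribution, or a placeholder."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return "not installed"


def environment_report(packages=("torch", "pytorch-ignite")) -> str:
    """Build a short plain-text report suitable for pasting into an issue."""
    lines = [
        f"OS: {platform.system()} {platform.release()}",
        f"Python: {platform.python_version()}",
    ]
    lines += [f"{name}: {installed_version(name)}" for name in packages]
    return "\n".join(lines)


print(environment_report())
```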

@dmalinverni

Sure, here you go

  • pytorch 2.3.1
  • ignite 0.5.1
  • The code runs normally with device='cpu'.
  • Relevant part of the traceback:

File "[...]/training.py", line 55, in train_one_epoch
epoch_prec.update((pred_logits>0,y))
File "[...]site-packages/ignite/metrics/metric.py", line 864, in wrapper
func(self, *args, **kwargs)
File "[...]site-packages/ignite/metrics/precision.py", line 414, in update
y_pred, y, correct = self._prepare_output(output)
File "[...]site-packages/ignite/metrics/precision.py", line 84, in _prepare_output
y_pred = y_pred.to(dtype=torch.float64, device=self._device)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

vfdev-5 added a commit that referenced this issue Feb 26, 2025
vfdev-5 added a commit that referenced this issue Feb 27, 2025
vfdev-5 added a commit that referenced this issue Feb 28, 2025
* Use float32 in metrics when metric device is MPS

Related to #3326

* Added mps f64 -> f32 cast to Metric class and applied in other metrics + new tests

* Fixed failing tests

* Update torch installation in mps-tests.yml

* Fix test_fid.py::test_statistics
@vfdev-5
Collaborator

vfdev-5 commented Feb 28, 2025

@blackplane @dmalinverni I've made a fix for MPS covering Precision, Recall and a few other metrics that use float64.
Feel free to test with ignite from master or the nightly release (available from tomorrow) and let me know if it works on your side.

@vfdev-5 vfdev-5 mentioned this issue Mar 23, 2025
27 tasks