Skip to content

Commit a078fe9

Browse files
committed
add cumsum workaround for newer PyTorch pytorch/pytorch#89784
1 parent cae3ca6 commit a078fe9

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

scripts/play.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
# monkey-patch _randn to use CPU random before k-diffusion uses it
44
from helpers.brownian_tree_mps_fix import reassuring_message
5+
from helpers.cumsum_mps_fix import reassuring_message as reassuring_message_2
56
from helpers.device import DeviceLiteral, get_device_type
67
from helpers.diffusers_denoiser import DiffusersSDDenoiser, DiffusersSD2Denoiser
78
from helpers.cfg_denoiser import Denoiser, DenoiserFactory
89
from helpers.log_intermediates import LogIntermediates, make_log_intermediates
910
from helpers.schedules import KarrasScheduleParams, KarrasScheduleTemplate, get_template_schedule
1011
print(reassuring_message) # avoid "unused" import :P
12+
print(reassuring_message_2)
1113

1214
import torch
1315
from torch import Generator, Tensor, randn, no_grad, argmin, zeros

src/helpers/cumsum_mps_fix.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# monkey-patch cumsum to fallback to CPU
2+
# https://github.com/pytorch/pytorch/issues/89784
3+
import torch
4+
from torch import cumsum, Tensor
5+
torch.cumsum = lambda input, *args, **kwargs: (
6+
cumsum(input.cpu() if input.device.type == 'mps' else input, *args, **kwargs).to(input.device)
7+
)
8+
orig_cumsum = torch.Tensor.cumsum
9+
def patched_cumsum(self: Tensor, *args, **kwargs):
10+
return orig_cumsum(self.cpu() if self.device.type == 'mps' else self, *args, **kwargs).to(self.device)
11+
torch.Tensor.cumsum = patched_cumsum
12+
13+
reassuring_message = "monkey-patched cumsum to fallback to CPU, for compatibility on MPS backend; see https://github.com/pytorch/pytorch/issues/89784"

0 commit comments

Comments
 (0)