Separate mps and other processes

This commit is contained in:
Ftps 2023-02-09 16:38:18 +09:00
parent 19697bbf07
commit a101942e8a

View file

@ -18,7 +18,10 @@ class DDIMSampler(object):
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
if str(self.device) == 'mps':
attr = attr.to(self.device, torch.float32)
else:
attr = attr.to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):