diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 27ead0e..7aeb7bc 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -16,8 +16,8 @@ class DDIMSampler(object): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.model.device): + attr = attr.to(torch.device(self.model.device)) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8..4a3e2ac 100644 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -19,8 +19,8 @@ class DPMSolverSampler(object): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.model.device): + attr = attr.to(torch.device(self.model.device)) setattr(self, name, attr) @torch.no_grad() diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 7002a36..5428d1e 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -18,8 +18,8 @@ class PLMSSampler(object): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.model.device): + attr = attr.to(torch.device(self.model.device)) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):