From 3aad79036736e92a3bac11a53dd0b700a988747e Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 16 Dec 2022 16:24:09 +0530 Subject: [PATCH] Don't hardcode the device id to 'cuda', instead use what the model is configured to --- ldm/models/diffusion/ddim.py | 4 ++-- ldm/models/diffusion/dpm_solver/sampler.py | 4 ++-- ldm/models/diffusion/plms.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 27ead0e..7aeb7bc 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -16,8 +16,8 @@ class DDIMSampler(object): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.model.device): + attr = attr.to(torch.device(self.model.device)) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8..4a3e2ac 100644 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -19,8 +19,8 @@ class DPMSolverSampler(object): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.model.device): + attr = attr.to(torch.device(self.model.device)) setattr(self, name, attr) @torch.no_grad() diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 7002a36..5428d1e 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -18,8 +18,8 @@ class PLMSSampler(object): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != torch.device(self.model.device): + attr = attr.to(torch.device(self.model.device)) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):