Don't hardcode the device id to 'cuda', instead use what the model is configured to

This commit is contained in:
cmdr2 2022-12-16 16:24:09 +05:30
parent cc77f2300d
commit 3aad790367
3 changed files with 6 additions and 6 deletions

View file

@ -16,8 +16,8 @@ class DDIMSampler(object):
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != torch.device(self.model.device):
attr = attr.to(torch.device(self.model.device))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):

View file

@ -19,8 +19,8 @@ class DPMSolverSampler(object):
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != torch.device(self.model.device):
attr = attr.to(torch.device(self.model.device))
setattr(self, name, attr)
@torch.no_grad()

View file

@ -18,8 +18,8 @@ class PLMSSampler(object):
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != torch.device(self.model.device):
attr = attr.to(torch.device(self.model.device))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):