mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
Don't hardcode the device id to 'cuda', instead use what the model is configured to
This commit is contained in:
parent
cc77f2300d
commit
3aad790367
3 changed files with 6 additions and 6 deletions
|
@ -16,8 +16,8 @@ class DDIMSampler(object):
|
|||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
if attr.device != torch.device(self.model.device):
|
||||
attr = attr.to(torch.device(self.model.device))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
|
|
|
@ -19,8 +19,8 @@ class DPMSolverSampler(object):
|
|||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
if attr.device != torch.device(self.model.device):
|
||||
attr = attr.to(torch.device(self.model.device))
|
||||
setattr(self, name, attr)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -18,8 +18,8 @@ class PLMSSampler(object):
|
|||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
if attr.device != torch.device(self.model.device):
|
||||
attr = attr.to(torch.device(self.model.device))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
|
|
Loading…
Reference in a new issue