mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 07:34:58 +00:00
Optimize load_model_from_config function
Optimize model loading function
This commit is contained in:
parent
cf1d67a6fd
commit
054294dfb3
1 changed files with 9 additions and 9 deletions
|
@ -25,7 +25,10 @@ def chunk(it, size):
|
|||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
|
||||
def load_model_from_config(config, ckpt, device="cuda", verbose=False):
|
||||
if device not in {"cuda", "cpu"}:
|
||||
raise ValueError(f"Incorrect device name. Received: {device}")
|
||||
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
|
@ -40,14 +43,11 @@ def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=Fa
|
|||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
if device == torch.device("cuda"):
|
||||
model.cuda()
|
||||
elif device == torch.device("cpu"):
|
||||
model.cpu()
|
||||
model.cond_stage_model.device = "cpu"
|
||||
else:
|
||||
raise ValueError(f"Incorrect device name. Received: {device}")
|
||||
model.eval()
|
||||
with torch.cuda.device(device) if device == "cuda" else torch.no_grad():
|
||||
if device == "cpu":
|
||||
model.cond_stage_model.device = "cpu"
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue