Optimize load_model_from_config function

Optimize model loading function
This commit is contained in:
Andres Caicedo 2023-05-08 13:02:41 +02:00
parent cf1d67a6fd
commit 054294dfb3
No known key found for this signature in database
GPG key ID: 6E797C4F5A327624

View file

@ -25,7 +25,10 @@ def chunk(it, size):
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
def load_model_from_config(config, ckpt, device="cuda", verbose=False):
if device not in {"cuda", "cpu"}:
raise ValueError(f"Incorrect device name. Received: {device}")
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
@ -40,13 +43,10 @@ def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=Fa
print("unexpected keys:")
print(u)
if device == torch.device("cuda"):
model.cuda()
elif device == torch.device("cpu"):
model.cpu()
with torch.cuda.device(device) if device == "cuda" else torch.no_grad():
if device == "cpu":
model.cond_stage_model.device = "cpu"
else:
raise ValueError(f"Incorrect device name. Received: {device}")
model.to(device)
model.eval()
return model