Optimize load_model_from_config function

Optimize model loading function
This commit is contained in:
Andres Caicedo 2023-05-08 13:02:41 +02:00
parent cf1d67a6fd
commit 054294dfb3
No known key found for this signature in database
GPG key ID: 6E797C4F5A327624

View file

@ -25,7 +25,10 @@ def chunk(it, size):
return iter(lambda: tuple(islice(it, size)), ()) return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False): def load_model_from_config(config, ckpt, device="cuda", verbose=False):
if device not in {"cuda", "cpu"}:
raise ValueError(f"Incorrect device name. Received: {device}")
print(f"Loading model from {ckpt}") print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd: if "global_step" in pl_sd:
@ -40,14 +43,11 @@ def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=Fa
print("unexpected keys:") print("unexpected keys:")
print(u) print(u)
if device == torch.device("cuda"): with torch.cuda.device(device) if device == "cuda" else torch.no_grad():
model.cuda() if device == "cpu":
elif device == torch.device("cpu"): model.cond_stage_model.device = "cpu"
model.cpu() model.to(device)
model.cond_stage_model.device = "cpu" model.eval()
else:
raise ValueError(f"Incorrect device name. Received: {device}")
model.eval()
return model return model