From 054294dfb34deb6f8cf6c9f6634346d8f384f717 Mon Sep 17 00:00:00 2001 From: Andres Caicedo Date: Mon, 8 May 2023 13:02:41 +0200 Subject: [PATCH] Optimize load_model_from_config function Optimize model loading function --- scripts/txt2img.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 9d955e3..adc1633 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -25,7 +25,10 @@ def chunk(it, size): return iter(lambda: tuple(islice(it, size)), ()) -def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False): +def load_model_from_config(config, ckpt, device="cuda", verbose=False): + if device not in {"cuda", "cpu"}: + raise ValueError(f"Incorrect device name. Received: {device}") + print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: @@ -40,14 +43,11 @@ def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=Fa print("unexpected keys:") print(u) - if device == torch.device("cuda"): - model.cuda() - elif device == torch.device("cpu"): - model.cpu() - model.cond_stage_model.device = "cpu" - else: - raise ValueError(f"Incorrect device name. Received: {device}") - model.eval() + with torch.cuda.device(device) if device == "cuda" else torch.no_grad(): + if device == "cpu": + model.cond_stage_model.device = "cpu" + model.to(device) + model.eval() return model