From e6050f3e58380f7e2c5086b4e98c58289c9f56b9 Mon Sep 17 00:00:00 2001 From: kjerk Date: Thu, 24 Nov 2022 01:48:48 -0800 Subject: [PATCH] Switch from no_grad to inference mode for scripts --- scripts/gradio/depth2img.py | 3 +-- scripts/gradio/inpainting.py | 3 +-- scripts/gradio/superresolution.py | 8 +++++--- scripts/img2img.py | 3 ++- scripts/streamlit/depth2img.py | 3 +-- scripts/streamlit/inpainting.py | 3 +-- scripts/streamlit/superresolution.py | 8 +++++--- scripts/txt2img.py | 5 ++--- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/scripts/gradio/depth2img.py b/scripts/gradio/depth2img.py index c791a4d..a77aac9 100644 --- a/scripts/gradio/depth2img.py +++ b/scripts/gradio/depth2img.py @@ -64,8 +64,7 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No wm_encoder = WatermarkEncoder() wm_encoder.set_watermark('bytes', wm.encode('utf-8')) - with torch.no_grad(),\ - torch.autocast("cuda"): + with torch.inference_mode(), torch.autocast("cuda"): batch = make_batch_sd( image, txt=prompt, device=device, num_samples=num_samples) z = model.get_first_stage_encoding(model.encode_first_stage( diff --git a/scripts/gradio/inpainting.py b/scripts/gradio/inpainting.py index 09d44f3..147ea31 100644 --- a/scripts/gradio/inpainting.py +++ b/scripts/gradio/inpainting.py @@ -81,8 +81,7 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1 start_code = torch.from_numpy(start_code).to( device=device, dtype=torch.float32) - with torch.no_grad(), \ - torch.autocast("cuda"): + with torch.inference_mode(), torch.autocast("cuda"): batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples) diff --git a/scripts/gradio/superresolution.py b/scripts/gradio/superresolution.py index 3d08fbf..dd3d00e 100644 --- a/scripts/gradio/superresolution.py +++ b/scripts/gradio/superresolution.py @@ -67,8 +67,8 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb wm = "SDV2" wm_encoder = WatermarkEncoder() wm_encoder.set_watermark('bytes', wm.encode('utf-8')) - with torch.no_grad(),\ - torch.autocast("cuda"): + + with torch.inference_mode(), torch.autocast("cuda"): batch = make_batch_sd( image, txt=prompt, device=device, num_samples=num_samples) c = model.cond_stage_model.encode(batch["txt"]) @@ -112,8 +112,10 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb x_T=start_code, callback=callback ) - with torch.no_grad(): + + with torch.inference_mode(): x_samples_ddim = model.decode_first_stage(samples) + result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result] diff --git a/scripts/img2img.py b/scripts/img2img.py index 9085ba9..77df33a 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -230,7 +230,8 @@ def main(): print(f"target t_enc is {t_enc} steps") precision_scope = autocast if opt.precision == "autocast" else nullcontext - with torch.no_grad(): + + with torch.inference_mode(): with precision_scope("cuda"): with model.ema_scope(): all_samples = list() diff --git a/scripts/streamlit/depth2img.py b/scripts/streamlit/depth2img.py index 7f80223..8794835 100644 --- a/scripts/streamlit/depth2img.py +++ b/scripts/streamlit/depth2img.py @@ -61,8 +61,7 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No wm_encoder = WatermarkEncoder() wm_encoder.set_watermark('bytes', wm.encode('utf-8')) - with torch.no_grad(),\ - torch.autocast("cuda"): + with torch.inference_mode(), torch.autocast("cuda"): batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples) z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space c = model.cond_stage_model.encode(batch["txt"]) diff --git a/scripts/streamlit/inpainting.py b/scripts/streamlit/inpainting.py index c35772f..c6fb3b8 100644 --- a/scripts/streamlit/inpainting.py +++ b/scripts/streamlit/inpainting.py @@ -79,8 +79,7 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1 start_code = prng.randn(num_samples, 4, h // 8, w // 8) start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32) - with torch.no_grad(), \ - torch.autocast("cuda"): + with torch.inference_mode(), torch.autocast("cuda"): batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples) c = model.cond_stage_model.encode(batch["txt"]) diff --git a/scripts/streamlit/superresolution.py b/scripts/streamlit/superresolution.py index c1172b0..87f4688 100644 --- a/scripts/streamlit/superresolution.py +++ b/scripts/streamlit/superresolution.py @@ -64,8 +64,8 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb wm = "SDV2" wm_encoder = WatermarkEncoder() wm_encoder.set_watermark('bytes', wm.encode('utf-8')) - with torch.no_grad(),\ - torch.autocast("cuda"): + + with torch.inference_mode(), torch.autocast("cuda"): batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples) c = model.cond_stage_model.encode(batch["txt"]) c_cat = list() @@ -105,8 +105,10 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb x_T=start_code, callback=callback ) - with torch.no_grad(): + + with torch.inference_mode(): x_samples_ddim = model.decode_first_stage(samples) + result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 st.text(f"upscaled image shape: {result.shape}") diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 1ed42a3..0eb356e 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -232,9 +232,8 @@ def main(opt): start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) precision_scope = autocast if opt.precision == "autocast" else nullcontext - with torch.no_grad(), \ - precision_scope("cuda"), \ - model.ema_scope(): + + with torch.inference_mode(), precision_scope("cuda"), model.ema_scope(): all_samples = list() for n in trange(opt.n_iter, desc="Sampling"): for prompts in tqdm(data, desc="data"):