Switch from no_grad to inference mode for scripts

kjerk 2022-11-24 01:48:48 -08:00
parent 33910c386e
commit e6050f3e58
8 changed files with 18 additions and 18 deletions
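
The change is mechanical: every torch.no_grad() block in these scripts becomes torch.inference_mode(). Both context managers disable gradient tracking, but inference_mode() additionally skips autograd's view tracking and version-counter bookkeeping, so pure inference runs slightly faster. A minimal sketch of the difference (the Linear model is just a hypothetical stand-in for the real diffusion model):

import torch

model = torch.nn.Linear(8, 8)   # hypothetical stand-in for the real model
x = torch.randn(1, 8)

with torch.no_grad():
    y = model(x)                # grads off, version counters still maintained

with torch.inference_mode():
    z = model(x)                # grads off and no version-counter bookkeeping

print(y.requires_grad, z.requires_grad)  # False False
print(z.is_inference())                  # True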

@@ -64,8 +64,7 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
-    with torch.no_grad(),\
-            torch.autocast("cuda"):
+    with torch.inference_mode(), torch.autocast("cuda"):
         batch = make_batch_sd(
             image, txt=prompt, device=device, num_samples=num_samples)
         z = model.get_first_stage_encoding(model.encode_first_stage(

@@ -81,8 +81,7 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
     start_code = torch.from_numpy(start_code).to(
         device=device, dtype=torch.float32)
-    with torch.no_grad(), \
-            torch.autocast("cuda"):
+    with torch.inference_mode(), torch.autocast("cuda"):
         batch = make_batch_sd(image, mask, txt=prompt,
                               device=device, num_samples=num_samples)

@@ -67,8 +67,8 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
     wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
-    with torch.no_grad(),\
-            torch.autocast("cuda"):
+
+    with torch.inference_mode(), torch.autocast("cuda"):
         batch = make_batch_sd(
             image, txt=prompt, device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])
@@ -112,8 +112,10 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
         x_T=start_code,
         callback=callback
     )
-    with torch.no_grad():
+
+    with torch.inference_mode():
         x_samples_ddim = model.decode_first_stage(samples)
     result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
     result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
     return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]

@@ -230,7 +230,8 @@ def main():
     print(f"target t_enc is {t_enc} steps")
     precision_scope = autocast if opt.precision == "autocast" else nullcontext
-    with torch.no_grad():
+
+    with torch.inference_mode():
         with precision_scope("cuda"):
             with model.ema_scope():
                 all_samples = list()
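
Note how the hunk above keeps the script's precision toggle intact: precision_scope is picked at runtime and simply nested inside the new context manager. Because contextlib.nullcontext accepts an optional argument, both branches can be called as precision_scope("cuda"). A short sketch of the pattern, with Opt as a hypothetical stand-in for the parsed CLI options:

from contextlib import nullcontext

import torch
from torch import autocast

class Opt:                      # hypothetical stand-in for argparse options
    precision = "autocast"

opt = Opt()
precision_scope = autocast if opt.precision == "autocast" else nullcontext

with torch.inference_mode():
    with precision_scope("cuda"):
        pass                    # the sampling loop would run here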

@@ -61,8 +61,7 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
-    with torch.no_grad(),\
-            torch.autocast("cuda"):
+    with torch.inference_mode(), torch.autocast("cuda"):
         batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
         z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key]))  # move to latent space
         c = model.cond_stage_model.encode(batch["txt"])

@@ -79,8 +79,7 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
     start_code = prng.randn(num_samples, 4, h // 8, w // 8)
     start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
-    with torch.no_grad(), \
-            torch.autocast("cuda"):
+    with torch.inference_mode(), torch.autocast("cuda"):
         batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])

@@ -64,8 +64,8 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
     wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
-    with torch.no_grad(),\
-            torch.autocast("cuda"):
+
+    with torch.inference_mode(), torch.autocast("cuda"):
         batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])
         c_cat = list()
@@ -105,8 +105,10 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
         x_T=start_code,
         callback=callback
     )
-    with torch.no_grad():
+
+    with torch.inference_mode():
         x_samples_ddim = model.decode_first_stage(samples)
     result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
     result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
     st.text(f"upscaled image shape: {result.shape}")

@@ -232,9 +232,8 @@ def main(opt):
     start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
     precision_scope = autocast if opt.precision == "autocast" else nullcontext
-    with torch.no_grad(), \
-        precision_scope("cuda"), \
-        model.ema_scope():
+    with torch.inference_mode(), precision_scope("cuda"), model.ema_scope():
         all_samples = list()
         for n in trange(opt.n_iter, desc="Sampling"):
             for prompts in tqdm(data, desc="data"):
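
One behavioral difference to keep in mind after this switch: tensors produced inside inference_mode() are permanent inference tensors and cannot be recorded by autograd later, whereas no_grad() outputs can. That is irrelevant for these one-shot scripts but matters if the results are ever reused in a training graph. A small sketch of the failure mode:

import torch

with torch.inference_mode():
    z = torch.ones(2, 2)        # z is an inference tensor

w = torch.ones(2, 2, requires_grad=True)
try:
    out = z * w                 # autograd must save z for w's gradient
except RuntimeError as err:
    print("expected:", err)     # work around by cloning z outside inference mode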