Switch from no_grad to inference mode for scripts

This commit is contained in:
kjerk 2022-11-24 01:48:48 -08:00
parent 33910c386e
commit e6050f3e58
8 changed files with 18 additions and 18 deletions

View file

@ -64,8 +64,7 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
with torch.no_grad(),\
torch.autocast("cuda"):
with torch.inference_mode(), torch.autocast("cuda"):
batch = make_batch_sd(
image, txt=prompt, device=device, num_samples=num_samples)
z = model.get_first_stage_encoding(model.encode_first_stage(

View file

@ -81,8 +81,7 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
start_code = torch.from_numpy(start_code).to(
device=device, dtype=torch.float32)
with torch.no_grad(), \
torch.autocast("cuda"):
with torch.inference_mode(), torch.autocast("cuda"):
batch = make_batch_sd(image, mask, txt=prompt,
device=device, num_samples=num_samples)

View file

@ -67,8 +67,8 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
wm = "SDV2"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
with torch.no_grad(),\
torch.autocast("cuda"):
with torch.inference_mode(), torch.autocast("cuda"):
batch = make_batch_sd(
image, txt=prompt, device=device, num_samples=num_samples)
c = model.cond_stage_model.encode(batch["txt"])
@ -112,8 +112,10 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
x_T=start_code,
callback=callback
)
with torch.no_grad():
with torch.inference_mode():
x_samples_ddim = model.decode_first_stage(samples)
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]

View file

@ -230,7 +230,8 @@ def main():
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
with torch.no_grad():
with torch.inference_mode():
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()

View file

@ -61,8 +61,7 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
with torch.no_grad(),\
torch.autocast("cuda"):
with torch.inference_mode(), torch.autocast("cuda"):
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space
c = model.cond_stage_model.encode(batch["txt"])

View file

@ -79,8 +79,7 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
start_code = prng.randn(num_samples, 4, h // 8, w // 8)
start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
with torch.no_grad(), \
torch.autocast("cuda"):
with torch.inference_mode(), torch.autocast("cuda"):
batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
c = model.cond_stage_model.encode(batch["txt"])

View file

@ -64,8 +64,8 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
wm = "SDV2"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
with torch.no_grad(),\
torch.autocast("cuda"):
with torch.inference_mode(), torch.autocast("cuda"):
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
c = model.cond_stage_model.encode(batch["txt"])
c_cat = list()
@ -105,8 +105,10 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
x_T=start_code,
callback=callback
)
with torch.no_grad():
with torch.inference_mode():
x_samples_ddim = model.decode_first_stage(samples)
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
st.text(f"upscaled image shape: {result.shape}")

View file

@ -232,9 +232,8 @@ def main(opt):
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
precision_scope = autocast if opt.precision == "autocast" else nullcontext
with torch.no_grad(), \
precision_scope("cuda"), \
model.ema_scope():
with torch.inference_mode(), precision_scope("cuda"), model.ema_scope():
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):