mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 23:55:00 +00:00
Switch from no_grad to inference mode for scripts
This commit is contained in:
parent
33910c386e
commit
e6050f3e58
8 changed files with 18 additions and 18 deletions
|
@ -64,8 +64,7 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
|
||||||
wm_encoder = WatermarkEncoder()
|
wm_encoder = WatermarkEncoder()
|
||||||
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
||||||
|
|
||||||
with torch.no_grad(),\
|
with torch.inference_mode(), torch.autocast("cuda"):
|
||||||
torch.autocast("cuda"):
|
|
||||||
batch = make_batch_sd(
|
batch = make_batch_sd(
|
||||||
image, txt=prompt, device=device, num_samples=num_samples)
|
image, txt=prompt, device=device, num_samples=num_samples)
|
||||||
z = model.get_first_stage_encoding(model.encode_first_stage(
|
z = model.get_first_stage_encoding(model.encode_first_stage(
|
||||||
|
|
|
@ -81,8 +81,7 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
|
||||||
start_code = torch.from_numpy(start_code).to(
|
start_code = torch.from_numpy(start_code).to(
|
||||||
device=device, dtype=torch.float32)
|
device=device, dtype=torch.float32)
|
||||||
|
|
||||||
with torch.no_grad(), \
|
with torch.inference_mode(), torch.autocast("cuda"):
|
||||||
torch.autocast("cuda"):
|
|
||||||
batch = make_batch_sd(image, mask, txt=prompt,
|
batch = make_batch_sd(image, mask, txt=prompt,
|
||||||
device=device, num_samples=num_samples)
|
device=device, num_samples=num_samples)
|
||||||
|
|
||||||
|
|
|
@ -67,8 +67,8 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
|
||||||
wm = "SDV2"
|
wm = "SDV2"
|
||||||
wm_encoder = WatermarkEncoder()
|
wm_encoder = WatermarkEncoder()
|
||||||
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
||||||
with torch.no_grad(),\
|
|
||||||
torch.autocast("cuda"):
|
with torch.inference_mode(), torch.autocast("cuda"):
|
||||||
batch = make_batch_sd(
|
batch = make_batch_sd(
|
||||||
image, txt=prompt, device=device, num_samples=num_samples)
|
image, txt=prompt, device=device, num_samples=num_samples)
|
||||||
c = model.cond_stage_model.encode(batch["txt"])
|
c = model.cond_stage_model.encode(batch["txt"])
|
||||||
|
@ -112,8 +112,10 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
|
||||||
x_T=start_code,
|
x_T=start_code,
|
||||||
callback=callback
|
callback=callback
|
||||||
)
|
)
|
||||||
with torch.no_grad():
|
|
||||||
|
with torch.inference_mode():
|
||||||
x_samples_ddim = model.decode_first_stage(samples)
|
x_samples_ddim = model.decode_first_stage(samples)
|
||||||
|
|
||||||
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
||||||
return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
|
return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
|
||||||
|
|
|
@ -230,7 +230,8 @@ def main():
|
||||||
print(f"target t_enc is {t_enc} steps")
|
print(f"target t_enc is {t_enc} steps")
|
||||||
|
|
||||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||||
with torch.no_grad():
|
|
||||||
|
with torch.inference_mode():
|
||||||
with precision_scope("cuda"):
|
with precision_scope("cuda"):
|
||||||
with model.ema_scope():
|
with model.ema_scope():
|
||||||
all_samples = list()
|
all_samples = list()
|
||||||
|
|
|
@ -61,8 +61,7 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
|
||||||
wm_encoder = WatermarkEncoder()
|
wm_encoder = WatermarkEncoder()
|
||||||
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
||||||
|
|
||||||
with torch.no_grad(),\
|
with torch.inference_mode(), torch.autocast("cuda"):
|
||||||
torch.autocast("cuda"):
|
|
||||||
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
|
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
|
||||||
z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space
|
z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space
|
||||||
c = model.cond_stage_model.encode(batch["txt"])
|
c = model.cond_stage_model.encode(batch["txt"])
|
||||||
|
|
|
@ -79,8 +79,7 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
|
||||||
start_code = prng.randn(num_samples, 4, h // 8, w // 8)
|
start_code = prng.randn(num_samples, 4, h // 8, w // 8)
|
||||||
start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
|
start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
|
||||||
|
|
||||||
with torch.no_grad(), \
|
with torch.inference_mode(), torch.autocast("cuda"):
|
||||||
torch.autocast("cuda"):
|
|
||||||
batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
|
batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
|
||||||
|
|
||||||
c = model.cond_stage_model.encode(batch["txt"])
|
c = model.cond_stage_model.encode(batch["txt"])
|
||||||
|
|
|
@ -64,8 +64,8 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
|
||||||
wm = "SDV2"
|
wm = "SDV2"
|
||||||
wm_encoder = WatermarkEncoder()
|
wm_encoder = WatermarkEncoder()
|
||||||
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
||||||
with torch.no_grad(),\
|
|
||||||
torch.autocast("cuda"):
|
with torch.inference_mode(), torch.autocast("cuda"):
|
||||||
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
|
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
|
||||||
c = model.cond_stage_model.encode(batch["txt"])
|
c = model.cond_stage_model.encode(batch["txt"])
|
||||||
c_cat = list()
|
c_cat = list()
|
||||||
|
@ -105,8 +105,10 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
|
||||||
x_T=start_code,
|
x_T=start_code,
|
||||||
callback=callback
|
callback=callback
|
||||||
)
|
)
|
||||||
with torch.no_grad():
|
|
||||||
|
with torch.inference_mode():
|
||||||
x_samples_ddim = model.decode_first_stage(samples)
|
x_samples_ddim = model.decode_first_stage(samples)
|
||||||
|
|
||||||
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
||||||
st.text(f"upscaled image shape: {result.shape}")
|
st.text(f"upscaled image shape: {result.shape}")
|
||||||
|
|
|
@ -232,9 +232,8 @@ def main(opt):
|
||||||
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
||||||
|
|
||||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||||
with torch.no_grad(), \
|
|
||||||
precision_scope("cuda"), \
|
with torch.inference_mode(), precision_scope("cuda"), model.ema_scope():
|
||||||
model.ema_scope():
|
|
||||||
all_samples = list()
|
all_samples = list()
|
||||||
for n in trange(opt.n_iter, desc="Sampling"):
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
for prompts in tqdm(data, desc="data"):
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
|
Loading…
Reference in a new issue