mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-21 23:24:59 +00:00
Add depth2img Gradio demo
This commit is contained in:
parent
cccfb98636
commit
05aea715a3
2 changed files with 192 additions and 2 deletions
10
README.md
10
README.md
|
@ -136,12 +136,18 @@ To augment the well-established [img2img](https://github.com/CompVis/stable-diff
|
|||
Note that the original method for image modification introduces significant semantic changes w.r.t. the initial image.
|
||||
If that is not desired, download our [depth-conditional stable diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-depth) model and the `dpt_hybrid` MiDaS [model weights](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt), place the latter in a folder `midas_models` and sample via
|
||||
```
|
||||
python scripts/streamlit/depth2img.py streamlit run scripts/demo/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml <path-to-ckpt>
|
||||
python scripts/gradio/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml <path-to-ckpt>
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
streamlit run scripts/streamlit/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml <path-to-ckpt>
|
||||
```
|
||||
|
||||
This method can be used on the samples of the base model itself.
|
||||
For example, take [this sample](assets/stable-samples/depth2img/old_man.png) generated by an anonymous discord user.
|
||||
Using the [streamlit](https://streamlit.io/) script `depth2img.py`, the MiDaS model first infers a monocular depth estimate given this input,
|
||||
Using the [gradio](https://gradio.app) or [streamlit](https://streamlit.io/) script `depth2img.py`, the MiDaS model first infers a monocular depth estimate given this input,
|
||||
and the diffusion model is then conditioned on the (relative) depth output.
|
||||
|
||||
<p align="center">
|
||||
|
|
184
scripts/gradio/depth2img.py
Normal file
184
scripts/gradio/depth2img.py
Normal file
|
@ -0,0 +1,184 @@
|
|||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from omegaconf import OmegaConf
|
||||
from einops import repeat, rearrange
|
||||
from pytorch_lightning import seed_everything
|
||||
from imwatermark import WatermarkEncoder
|
||||
|
||||
from scripts.txt2img import put_watermark
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.data.util import AddMiDaS
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
def initialize_model(config, ckpt):
|
||||
config = OmegaConf.load(config)
|
||||
model = instantiate_from_config(config.model)
|
||||
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
|
||||
|
||||
device = torch.device(
|
||||
"cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
sampler = DDIMSampler(model)
|
||||
return sampler
|
||||
|
||||
|
||||
def make_batch_sd(
|
||||
image,
|
||||
txt,
|
||||
device,
|
||||
num_samples=1,
|
||||
model_type="dpt_hybrid"
|
||||
):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
# sample['jpg'] is tensor hwc in [-1, 1] at this point
|
||||
midas_trafo = AddMiDaS(model_type=model_type)
|
||||
batch = {
|
||||
"jpg": image,
|
||||
"txt": num_samples * [txt],
|
||||
}
|
||||
batch = midas_trafo(batch)
|
||||
batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w')
|
||||
batch["jpg"] = repeat(batch["jpg"].to(device=device),
|
||||
"1 ... -> n ...", n=num_samples)
|
||||
batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(
|
||||
device=device), "1 ... -> n ...", n=num_samples)
|
||||
return batch
|
||||
|
||||
|
||||
def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
|
||||
do_full_sample=False):
|
||||
device = torch.device(
|
||||
"cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = sampler.model
|
||||
seed_everything(seed)
|
||||
|
||||
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
||||
wm = "SDV2"
|
||||
wm_encoder = WatermarkEncoder()
|
||||
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
||||
|
||||
with torch.no_grad(),\
|
||||
torch.autocast("cuda"):
|
||||
batch = make_batch_sd(
|
||||
image, txt=prompt, device=device, num_samples=num_samples)
|
||||
z = model.get_first_stage_encoding(model.encode_first_stage(
|
||||
batch[model.first_stage_key])) # move to latent space
|
||||
c = model.cond_stage_model.encode(batch["txt"])
|
||||
c_cat = list()
|
||||
for ck in model.concat_keys:
|
||||
cc = batch[ck]
|
||||
cc = model.depth_model(cc)
|
||||
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
|
||||
keepdim=True)
|
||||
display_depth = (cc - depth_min) / (depth_max - depth_min)
|
||||
depth_image = Image.fromarray(
|
||||
(display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8))
|
||||
cc = torch.nn.functional.interpolate(
|
||||
cc,
|
||||
size=z.shape[2:],
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
|
||||
keepdim=True)
|
||||
cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1.
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
|
||||
# cond
|
||||
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
|
||||
# uncond cond
|
||||
uc_cross = model.get_unconditional_conditioning(num_samples, "")
|
||||
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
||||
if not do_full_sample:
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(
|
||||
z, torch.tensor([t_enc] * num_samples).to(model.device))
|
||||
else:
|
||||
z_enc = torch.randn_like(z)
|
||||
# decode it
|
||||
samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
|
||||
unconditional_conditioning=uc_full, callback=callback)
|
||||
x_samples_ddim = model.decode_first_stage(samples)
|
||||
result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
|
||||
return [depth_image] + [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
|
||||
|
||||
|
||||
def pad_image(input_image):
|
||||
pad_w, pad_h = np.max(((2, 2), np.ceil(
|
||||
np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
|
||||
im_padded = Image.fromarray(
|
||||
np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
||||
return im_padded
|
||||
|
||||
|
||||
def predict(input_image, prompt, steps, num_samples, scale, seed, eta, strength):
|
||||
init_image = input_image.convert("RGB")
|
||||
image = pad_image(init_image) # resize to integer multiple of 32
|
||||
|
||||
sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
|
||||
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
do_full_sample = strength == 1.
|
||||
t_enc = min(int(strength * steps), steps-1)
|
||||
result = paint(
|
||||
sampler=sampler,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
t_enc=t_enc,
|
||||
seed=seed,
|
||||
scale=scale,
|
||||
num_samples=num_samples,
|
||||
callback=None,
|
||||
do_full_sample=do_full_sample
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
sampler = initialize_model(sys.argv[1], sys.argv[2])
|
||||
|
||||
block = gr.Blocks().queue()
|
||||
with block:
|
||||
with gr.Row():
|
||||
gr.Markdown("## Stable Diffusion Depth2Img")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
input_image = gr.Image(source='upload', type="pil")
|
||||
prompt = gr.Textbox(label="Prompt")
|
||||
run_button = gr.Button(label="Run")
|
||||
with gr.Accordion("Advanced options", open=False):
|
||||
num_samples = gr.Slider(
|
||||
label="Images", minimum=1, maximum=4, value=1, step=1)
|
||||
ddim_steps = gr.Slider(label="Steps", minimum=1,
|
||||
maximum=50, value=50, step=1)
|
||||
scale = gr.Slider(
|
||||
label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1
|
||||
)
|
||||
strength = gr.Slider(
|
||||
label="Strength", minimum=0.0, maximum=1.0, value=0.9, step=0.01
|
||||
)
|
||||
seed = gr.Slider(
|
||||
label="Seed",
|
||||
minimum=0,
|
||||
maximum=2147483647,
|
||||
step=1,
|
||||
randomize=True,
|
||||
)
|
||||
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
||||
with gr.Column():
|
||||
gallery = gr.Gallery(label="Generated images", show_label=False).style(
|
||||
grid=[2], height="auto")
|
||||
|
||||
run_button.click(fn=predict, inputs=[
|
||||
input_image, prompt, ddim_steps, num_samples, scale, seed, eta, strength], outputs=[gallery])
|
||||
|
||||
|
||||
block.launch()
|
Loading…
Reference in a new issue