StableDiffusion/scripts/inpaint.py
2023-07-05 14:30:41 +00:00

187 lines
6.2 KiB
Python

import sys
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
from einops import repeat
from imwatermark import WatermarkEncoder
from pathlib import Path
import argparse
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
# Switch autograd off for the whole process: nothing in this script calls
# backward, so no gradient bookkeeping is needed.
torch.set_grad_enabled(False)
def put_watermark(img, wm_encoder=None):
    """Embed an invisible watermark into ``img`` when an encoder is supplied.

    Args:
        img: RGB image; it is converted with ``np.array`` before encoding.
        wm_encoder: Object exposing ``encode(bgr_array, 'dwtDct')``. When it
            is ``None`` the image is handed back untouched.

    Returns:
        ``img`` itself if no encoder was given, otherwise a new PIL image
        built from the encoder output.
    """
    if wm_encoder is None:
        return img

    # The encoder is fed BGR data, so swap the channel order first.
    bgr_pixels = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    marked = wm_encoder.encode(bgr_pixels, 'dwtDct')
    # Reverse the last axis to return to RGB before wrapping in a PIL image.
    return Image.fromarray(marked[:, :, ::-1])
def initialize_model(config, ckpt):
    """Build the model described by ``config``, load weights, wrap in a sampler.

    Args:
        config: Path to a config file loadable by ``OmegaConf.load``; its
            ``model`` entry is passed to ``instantiate_from_config``.
        ckpt: Path to a checkpoint whose ``"state_dict"`` entry holds the
            weights. Keys are matched non-strictly.

    Returns:
        A ``DDIMSampler`` wrapping the model. The model is moved to CUDA when
        it is available and to the CPU otherwise.
    """
    model = instantiate_from_config(OmegaConf.load(config).model)

    # Deserialize onto the CPU first so a checkpoint saved from a GPU can
    # still be opened on a CPU-only host; the model is moved to the final
    # device below.
    # NOTE(security): torch.load unpickles the file - only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return DDIMSampler(model)
def make_batch_sd(
        image,
        mask,
        txt,
        device,
        num_samples=1):
    """Assemble the conditioning batch for one image/mask/prompt triple.

    Args:
        image: PIL image; converted to RGB and rescaled via ``x / 127.5 - 1``.
        mask: PIL image; converted to single-channel "L" and binarised at 0.5
            (after dividing by 255).
        txt: Prompt, repeated ``num_samples`` times in the batch.
        device: Target device for all tensors in the batch.
        num_samples: Number of copies along the leading (batch) axis.

    Returns:
        Dict with keys ``"image"``, ``"txt"``, ``"mask"`` and
        ``"masked_image"``; the masked image is the image with every pixel
        where the mask is 1 zeroed out.
    """
    def _tile(tensor):
        # Move to the target device, then replicate the singleton batch axis.
        return repeat(tensor.to(device=device), "1 ... -> n ...", n=num_samples)

    # HWC pixels -> 1xCxHxW float tensor.
    rgb = np.array(image.convert("RGB"))[None].transpose(0, 3, 1, 2)
    image_t = torch.from_numpy(rgb).to(dtype=torch.float32) / 127.5 - 1.0

    # HxW mask -> 1x1xHxW float tensor holding only 0.0 and 1.0.
    gray = np.array(mask.convert("L")).astype(np.float32) / 255.0
    binary = (gray >= 0.5).astype(np.float32)[None, None]
    mask_t = torch.from_numpy(binary)

    # Keep only the pixels outside the masked region.
    masked_t = image_t * (mask_t < 0.5)

    return {
        "image": _tile(image_t),
        "txt": num_samples * [txt],
        "mask": _tile(mask_t),
        "masked_image": _tile(masked_t),
    }
def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
    """Run DDIM sampling to inpaint the masked region of ``image``.

    Args:
        sampler: Sampler exposing ``.model`` and ``.sample(...)``.
        image: PIL image to inpaint.
        mask: PIL mask image passed to ``make_batch_sd``.
        prompt: Text prompt used for the cross-attention conditioning.
        seed: Seed for the NumPy RNG that draws the initial latent noise.
        scale: Unconditional (classifier-free) guidance scale.
        ddim_steps: Number of DDIM steps.
        num_samples: Number of images to generate.
        w: Width in pixels; the latent width is ``w // 8``.
        h: Height in pixels; the latent height is ``h // 8``.

    Returns:
        List of ``num_samples`` watermarked PIL images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = sampler.model

    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', "SDV2".encode('utf-8'))

    # Latents live at 1/8 of the pixel resolution.
    latent_h, latent_w = h // 8, w // 8

    # Seeded NumPy noise gives a reproducible starting latent.
    rng = np.random.RandomState(seed)
    noise = rng.randn(num_samples, 4, latent_h, latent_w)
    start_code = torch.from_numpy(noise).to(device=device, dtype=torch.float32)

    # NOTE(review): autocast is requested for "cuda" even when the CPU
    # fallback above was selected - confirm this is intended.
    with torch.no_grad(), torch.autocast("cuda"):
        batch = make_batch_sd(image, mask, txt=prompt,
                              device=device, num_samples=num_samples)

        text_cond = model.cond_stage_model.encode(batch["txt"])

        # Build the channel-wise concat conditioning: the masked image goes
        # through the first-stage encoder, every other key is resized to the
        # latent resolution.
        concat_parts = []
        for key in model.concat_keys:
            part = batch[key].float()
            if key == model.masked_image_key:
                part = model.get_first_stage_encoding(
                    model.encode_first_stage(part))
            else:
                part = torch.nn.functional.interpolate(
                    part, size=[latent_h, latent_w])
            concat_parts.append(part)
        c_cat = torch.cat(concat_parts, dim=1)

        # Conditional branch.
        cond = {"c_concat": [c_cat], "c_crossattn": [text_cond]}

        # Unconditional branch: same concat input, empty-prompt embedding.
        uc_cross = model.get_unconditional_conditioning(num_samples, "")
        uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}

        latents, _ = sampler.sample(
            ddim_steps,
            num_samples,
            [model.channels, latent_h, latent_w],
            cond,
            verbose=False,
            eta=1.0,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc_full,
            x_T=start_code,
        )

        decoded = model.decode_first_stage(latents)

        # Map [-1, 1] -> [0, 1], then to NHWC pixel values in [0, 255].
        pixels = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
        pixels = pixels.cpu().numpy().transpose(0, 2, 3, 1) * 255

    return [
        put_watermark(Image.fromarray(frame.astype(np.uint8)), wm_encoder)
        for frame in pixels
    ]
def pad_image(input_image):
    """Pad a PIL image on the right and bottom to a multiple of 64 pixels.

    Each side is rounded up to the next multiple of 64, with a minimum of
    128 pixels (two 64-pixel blocks). New pixels replicate the nearest edge
    value. Both multi-band (H x W x C) and single-band (H x W) images are
    accepted.

    Args:
        input_image: PIL image to pad.

    Returns:
        A new PIL image built from the padded pixel array.
    """
    pixels = np.array(input_image)
    size = np.array(input_image.size)  # PIL reports (width, height)

    # Number of 64-pixel blocks per side, never fewer than two.
    blocks = np.max(((2, 2), np.ceil(size / 64).astype(int)), axis=0)
    pad_w, pad_h = blocks * 64 - size

    # Array axes are (rows=height, cols=width[, bands]); any trailing axes
    # are left unpadded so single-band images work as well as RGB ones.
    pad_spec = ((0, pad_h), (0, pad_w)) + ((0, 0),) * (pixels.ndim - 2)
    return Image.fromarray(np.pad(pixels, pad_spec, mode='edge'))
def predict(input_image, prompt, ddim_steps, num_samples, scale, seed, sampler):
    """Pad the image/mask pair and run inpainting on it.

    Args:
        input_image: Mapping with PIL images under ``"image"`` and ``"mask"``.
        prompt: Text prompt forwarded to ``inpaint``.
        ddim_steps: Number of DDIM steps.
        num_samples: Number of images to generate.
        scale: Guidance scale forwarded to ``inpaint``.
        seed: Random seed forwarded to ``inpaint``.
        sampler: Sampler forwarded to ``inpaint``.

    Returns:
        Whatever ``inpaint`` returns for the padded inputs; the results have
        the padded size, not the original one.
    """
    # Both inputs are forced to RGB and padded (pad_image rounds each side up
    # to a multiple of 64).
    padded_image = pad_image(input_image["image"].convert("RGB"))
    padded_mask = pad_image(input_image["mask"].convert("RGB"))

    width, height = padded_image.size
    print("Inpainting...", width, height)

    return inpaint(
        sampler=sampler,
        image=padded_image,
        mask=padded_mask,
        prompt=prompt,
        seed=seed,
        scale=scale,
        ddim_steps=ddim_steps,
        num_samples=num_samples,
        h=height, w=width
    )
def parse_args(argv=None):
    """Parse the command-line options of the inpainting script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        An ``argparse.Namespace`` with ``config``, ``ckpt``, ``src``,
        ``prompt``, ``mask``, ``dir``, ``steps`` and ``n_sample``.

    Raises:
        SystemExit: If ``--src`` or ``--mask`` is missing or an option is
            malformed (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(description='Image inpainting')
    parser.add_argument('--config', type=str, default="configs/stable-diffusion/v2-inpainting-inference.yaml", help='config path')
    parser.add_argument('--ckpt', type=str, default="512-inpainting-ema.ckpt", help='Model checkpoint')
    # The source image and the mask are opened unconditionally by the entry
    # point, so fail here with a clear usage error instead of later on None.
    parser.add_argument('--src', type=str, required=True, help='Source image path')
    # An empty prompt is a usable default; None is not a valid prompt string.
    parser.add_argument('--prompt', type=str, default='', help='Description for source image')
    parser.add_argument('--mask', type=str, required=True, help='Mask path')
    parser.add_argument('--dir', type=str, default='', help='Directory where generated samples are saved')
    parser.add_argument('--steps', type=int, default=45, help='Number of DDIM sample steps')
    parser.add_argument('--n_sample', type=int, default=4, help='Number of samples')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Fixed sampling settings used by the command-line entry point.
    GUIDANCE_SCALE = 12
    SEED = 991108

    args = parse_args()

    # Create the output directory before the expensive model load so a bad
    # path fails early; an empty --dir resolves to the current directory.
    out_dir = Path(args.dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    sampler = initialize_model(args.config, args.ckpt)

    # predict() converts both inputs to new RGB images, so the source files
    # can be closed as soon as it returns.
    with Image.open(args.src) as src_image, Image.open(args.mask) as mask_image:
        input_pair = {"image": src_image, "mask": mask_image}
        results = predict(input_pair, args.prompt, args.steps, args.n_sample,
                          GUIDANCE_SCALE, SEED, sampler)

    # Output names reuse the source file name, so the saved format follows
    # the source file's extension.
    src_name = Path(args.src).name
    for i, img in enumerate(results):
        img.save(out_dir / f"inpaint_{i}_{src_name}")