From f70130a587eb0431c053c6366f2bc874657c78b2 Mon Sep 17 00:00:00 2001 From: Yewon Lim Date: Wed, 5 Jul 2023 14:30:41 +0000 Subject: [PATCH] add inpaint.py --- scripts/inpaint.py | 187 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 scripts/inpaint.py diff --git a/scripts/inpaint.py b/scripts/inpaint.py new file mode 100644 index 0000000..2a724d8 --- /dev/null +++ b/scripts/inpaint.py @@ -0,0 +1,187 @@ +import sys +import cv2 +import torch +import numpy as np +import gradio as gr +from PIL import Image +from omegaconf import OmegaConf +from einops import repeat +from imwatermark import WatermarkEncoder +from pathlib import Path +import argparse + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config + + +torch.set_grad_enabled(False) + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def initialize_model(config, ckpt): + config = OmegaConf.load(config) + model = instantiate_from_config(config.model) + + model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) + + device = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + return sampler + + +def make_batch_sd( + image, + mask, + txt, + device, + num_samples=1): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + batch = { + "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples), + "txt": num_samples * [txt], + "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples), + "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples), + } + return batch + + +def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512): + device = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + model = sampler.model + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "SDV2" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + prng = np.random.RandomState(seed) + start_code = prng.randn(num_samples, 4, h // 8, w // 8) + start_code = torch.from_numpy(start_code).to( + device=device, dtype=torch.float32) + + with torch.no_grad(), \ + torch.autocast("cuda"): + batch = make_batch_sd(image, mask, txt=prompt, + device=device, num_samples=num_samples) + + c = model.cond_stage_model.encode(batch["txt"]) + + c_cat = list() + for ck in model.concat_keys: + cc = batch[ck].float() + if ck != model.masked_image_key: + bchw = [num_samples, 4, h // 8, w // 8] + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = model.get_first_stage_encoding( + model.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + + # cond + cond = {"c_concat": [c_cat], "c_crossattn": [c]} + + # uncond cond + uc_cross = model.get_unconditional_conditioning(num_samples, "") + uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} + + shape = [model.channels, h // 8, w // 8] + samples_cfg, intermediates = sampler.sample( + ddim_steps, + num_samples, + shape, + cond, + verbose=False, + eta=1.0, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc_full, + x_T=start_code, + ) + x_samples_ddim = model.decode_first_stage(samples_cfg) + + result = torch.clamp((x_samples_ddim + 1.0) / 2.0, + min=0.0, max=1.0) + + result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 + return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result] + +def pad_image(input_image): + pad_w, pad_h = np.max(((2, 2), np.ceil( + np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size + im_padded = Image.fromarray( + np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) + return im_padded + +def predict(input_image, prompt, ddim_steps, num_samples, scale, seed, sampler): + init_image = input_image["image"].convert("RGB") + init_mask = input_image["mask"].convert("RGB") + image = pad_image(init_image) # resize to integer multiple of 32 + mask = pad_image(init_mask) # resize to integer multiple of 32 + width, height = image.size + print("Inpainting...", width, height) + + result = inpaint( + sampler=sampler, + image=image, + mask=mask, + prompt=prompt, + seed=seed, + scale=scale, + ddim_steps=ddim_steps, + num_samples=num_samples, + h=height, w=width + ) + + return result + + + +def parse_args(): + parser = argparse.ArgumentParser(description='Image inpainting') + parser.add_argument('--config', type=str, default="configs/stable-diffusion/v2-inpainting-inference.yaml", help='config path') + parser.add_argument('--ckpt', type=str, default="512-inpainting-ema.ckpt", help='Model checkpoint') + parser.add_argument('--src', type=str, help='Source image path') + parser.add_argument('--prompt', type=str, help='Description for source image') + parser.add_argument('--mask', type=str, help='Mask path') + parser.add_argument('--dir', type=str, default='', help='Directory where generated samples are saved') + parser.add_argument('--steps', type=int, default=45, help='Number of DDIM sample steps') + parser.add_argument('--n_sample', type=int, default=4, help='Number of samples') + args = parser.parse_args() + return args + + +if __name__ == "__main__": + import os.path as osp + + args = parse_args() + sampler = initialize_model(args.config, args.ckpt) + input_pair = {"image":Image.open(args.src), + "mask":Image.open(args.mask)} + results = predict(input_pair, args.prompt, args.steps, args.n_sample, 12, 991108, sampler) + for i, img in enumerate(results): + img.save(osp.join(args.dir, f"inpaint_{i}_{osp.basename(args.src)}")) + +