diff --git a/README.md b/README.md
index c06f3c0..eb44b04 100644
--- a/README.md
+++ b/README.md
@@ -137,15 +137,28 @@ Note: The inference config for all model versions is designed to be used with EM
 For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
 non-EMA to EMA weights.
 
-### Stable Diffusion Meets Karlo
-![upscaling-x4](assets/stable-samples/stable-unclip/panda.jpg)
-_++++++ NOTE: preliminary checkpoint for internal testing ++++++_
+### Stable unCLIP
+_++++++ NOTE: preliminary checkpoints for internal testing ++++++_
 
-Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125) (also known as DALL·E 2).
-We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior, and Stable Diffusion v2.1.
-More precisely, we finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings.
-This means that the model can be used to produce image variations in the style of unCLIP, but can also be combined with the
-embedding prior of KARLO and directly decodes to 768x768 pixel resolution.
+[unCLIP](https://openai.com/dall-e-2/) is the approach behind OpenAI's [DALL·E 2](https://openai.com/dall-e-2/),
+trained to invert CLIP image embeddings.
+We finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings.
+This means that the model can be used to produce image variations, but can also be combined with a text-to-image
+embedding prior to yield a full text-to-image model at 768x768 resolution.
+We provide two models, trained on OpenAI CLIP-L and OpenCLIP-H image embeddings, respectively, available
+_[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview).
+To use them, download the weights from Hugging Face and put them into the `checkpoints` folder.
+#### Image Variations
+![image-variations-h](assets/stable-samples/stable-unclip/castle.jpg)
+![image-variations-h](assets/stable-samples/stable-unclip/cornmen.jpg)
+
+_++TODO: Input images from the DIV2K dataset. Proceed with care++_
+
+#### Stable Diffusion Meets Karlo
+![panda](assets/stable-samples/stable-unclip/panda.jpg)
+
+Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125).
+We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior and Stable Diffusion v2.1-768.
 
 To run the model, first download the KARLO checkpoints
 ```shell
@@ -156,7 +169,7 @@ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b623
 wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
 cd ../../
 ```
 
-and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder`
+and the finetuned SD2.1 unCLIP-L checkpoint _[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints` folder
 
 Then, run
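Note (illustration, not part of the patch): the README stops at "put the weights into the `checkpoints` folder" without showing how a config/checkpoint pair is consumed outside the Streamlit demo. A minimal sketch, assuming the repo's usual OmegaConf + `instantiate_from_config` loading path, the checkpoint filename that `scripts/streamlit/stableunclip.py` below expects, and that the checkpoint stores its weights under a `"state_dict"` key like the other SD 2.x checkpoints:

```python
# Sketch only: load the finetuned unCLIP-L model outside the Streamlit demo.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml")
model = instantiate_from_config(config.model)
ckpt = torch.load("checkpoints/v2-1-stable-unclip-l-ft.ckpt", map_location="cpu")
missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)  # assumes a "state_dict" key
model = model.cuda().eval()
```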
diff --git a/checkpoints/checkpoints.txt b/checkpoints/checkpoints.txt
new file mode 100644
index 0000000..d92df31
--- /dev/null
+++ b/checkpoints/checkpoints.txt
@@ -0,0 +1 @@
+Put unCLIP checkpoints here.
\ No newline at end of file
diff --git a/configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml b/configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml
new file mode 100644
index 0000000..6c41acc
--- /dev/null
+++ b/configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml
@@ -0,0 +1,80 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
+  params:
+    embedding_dropout: 0.25
+    parameterization: "v"
+    linear_start: 0.00085
+    linear_end: 0.0120
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 96
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn-adm
+    scale_factor: 0.18215
+    monitor: val/loss_simple_ema
+
+    embedder_config:
+      target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+
+    noise_aug_config:
+      target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
+      params:
+        timestep_dim: 1024
+        noise_schedule_config:
+          timesteps: 1000
+          beta_schedule: squaredcos_cap_v2
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        num_classes: "sequential"
+        adm_in_channels: 2048
+        use_checkpoint: True
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: "softmax-xformers"
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: [ ]
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
diff --git a/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml b/configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml
similarity index 100%
rename from configs/stable-diffusion/v2-1-stable-karlo-inference.yaml
rename to configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml
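Note (illustration, not part of the patch): the dimensions in the new unCLIP-H config fit together as follows. That the UNet's `adm_in_channels: 2048` corresponds to the image embedding concatenated with the noise-level embedding is an assumption read off the config, not something shown in this diff:

```python
# Shape bookkeeping for v2-1-stable-unclip-h-inference.yaml (illustration only).
image_embed_dim = 1024   # pooled OpenCLIP ViT-H/14 embedding from FrozenOpenCLIPImageEmbedder
noise_level_dim = 1024   # CLIPEmbeddingNoiseAugmentation: timestep_dim
adm_in_channels = 2048   # unet_config.params.adm_in_channels
assert image_embed_dim + noise_level_dim == adm_in_channels

# Latent geometry: 768x768 RGB images, VAE downsampling factor f=8, 4 latent channels.
H = W = 768
f, C = 8, 4
assert (C, H // f, W // f) == (4, 96, 96)   # matches channels: 4 / image_size: 96 above
```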
diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py
index 095e5ba..fbe7aac 100644
--- a/ldm/models/diffusion/dpm_solver/dpm_solver.py
+++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -307,7 +307,16 @@ def model_wrapper(
             else:
                 x_in = torch.cat([x] * 2)
                 t_in = torch.cat([t_continuous] * 2)
-                c_in = torch.cat([unconditional_condition, condition])
+                if isinstance(condition, dict):
+                    assert isinstance(unconditional_condition, dict)
+                    c_in = dict()
+                    for k in condition:
+                        if isinstance(condition[k], list):
+                            c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
+                        else:
+                            c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
+                else:
+                    c_in = torch.cat([unconditional_condition, condition])
                 noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                 return noise_uncond + guidance_scale * (noise - noise_uncond)
 
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index fcc5826..d89481f 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -6,7 +6,8 @@ from torch.utils.checkpoint import checkpoint
 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
 
 import open_clip
-from ldm.util import default, count_params
+import kornia
+from ldm.util import default, count_params, autocast
 
 
 class AbstractEncoder(nn.Module):
@@ -232,6 +232,63 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         return self(text)
 
 
+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+    """
+    Uses the OpenCLIP vision transformer encoder for images
+    """
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
+        super().__init__()
+        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
+                                                            pretrained=version, )
+        del model.transformer
+        self.model = model
+
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        if self.layer == "penultimate":
+            raise NotImplementedError()
+            self.layer_idx = 1
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+        self.ucg_rate = ucg_rate
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        x = kornia.geometry.resize(x, (224, 224),
+                                   interpolation='bicubic', align_corners=True,
+                                   antialias=self.antialias)
+        x = (x + 1.) / 2.
+        # renormalize according to clip
+        x = kornia.enhance.normalize(x, self.mean, self.std)
+        return x
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    @autocast
+    def forward(self, image, no_dropout=False):
+        z = self.encode_with_vision_transformer(image)
+        if self.ucg_rate > 0. and not no_dropout:
+            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
+        return z
+
+    def encode_with_vision_transformer(self, img):
+        img = self.preprocess(img)
+        x = self.model.visual(img)
+        return x
+
+    def encode(self, text):
+        return self(text)
+
+
 class FrozenCLIPT5Encoder(AbstractEncoder):
     def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                  clip_max_length=77, t5_max_length=77):
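Note (illustration, not part of the patch): a usage sketch for the new image embedder. The `[-1, 1]` input range follows `preprocess()`, the 1024-d pooled output is what the OpenCLIP ViT-H/14 projection produces, and the module is moved to CUDA since `forward` is wrapped in the CUDA `autocast` decorator added in `ldm/util.py` below:

```python
# Sketch only: embed a batch of images with the new FrozenOpenCLIPImageEmbedder.
import torch
from ldm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder

embedder = FrozenOpenCLIPImageEmbedder(ucg_rate=0.).cuda()    # downloads the OpenCLIP ViT-H/14 weights
images = torch.rand(2, 3, 512, 512, device="cuda") * 2. - 1.  # preprocess() expects inputs in [-1, 1]
with torch.no_grad():
    z = embedder(images)                                      # expected shape: (2, 1024)
```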
diff --git a/ldm/util.py b/ldm/util.py
index 8c09ca1..9ede259 100644
--- a/ldm/util.py
+++ b/ldm/util.py
@@ -8,6 +8,16 @@ from inspect import isfunction
 from PIL import Image, ImageDraw, ImageFont
 
 
+def autocast(f):
+    def do_autocast(*args, **kwargs):
+        with torch.cuda.amp.autocast(enabled=True,
+                                     dtype=torch.get_autocast_gpu_dtype(),
+                                     cache_enabled=torch.is_autocast_cache_enabled()):
+            return f(*args, **kwargs)
+
+    return do_autocast
+
+
 def log_txt_as_img(wh, xc, size=10):
     # wh a tuple of (width, height)
     # xc a list of captions to plot
diff --git a/scripts/streamlit/stablekarlo.py b/scripts/streamlit/stableunclip.py
similarity index 94%
rename from scripts/streamlit/stablekarlo.py
rename to scripts/streamlit/stableunclip.py
index 5121bd2..4af9051 100644
--- a/scripts/streamlit/stablekarlo.py
+++ b/scripts/streamlit/stableunclip.py
@@ -22,10 +22,11 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
 torch.set_grad_enabled(False)
 
 PROMPTS_ROOT = "scripts/prompts/"
-SAVE_PATH = "outputs/demo/stable-karlo/"
+SAVE_PATH = "outputs/demo/stable-unclip/"
 
 VERSION2SPECS = {
-    "Stable Karlo": {"H": 768, "W": 768, "C": 4, "f": 8},
+    "Stable unCLIP-L": {"H": 768, "W": 768, "C": 4, "f": 8},
+    "Stable unOpenCLIP-H": {"H": 768, "W": 768, "C": 4, "f": 8},
     "Full Karlo": {}
 }
 
@@ -193,12 +194,16 @@ def torch2np(x):
 
 
 @st.cache(allow_output_mutation=True, suppress_st_warning=True)
-def init(version="Stable Karlo", load_karlo_prior=False):
+def init(version="Stable unCLIP-L", load_karlo_prior=False):
     state = dict()
     if not "model" in state:
-        if version == "Stable Karlo":
-            config = "configs/stable-diffusion/v2-1-stable-karlo-inference.yaml"
-            ckpt = "checkpoints/v2-1-stable-unclip-ft.ckpt"
+        if version == "Stable unCLIP-L":
+            config = "configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml"
+            ckpt = "checkpoints/v2-1-stable-unclip-l-ft.ckpt"
+
+        elif version == "Stable unOpenCLIP-H":
+            config = "configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml"
+            ckpt = "checkpoints/v2-1-stable-unclip-h-ft.ckpt"
 
         elif version == "Full Karlo":
             from ldm.modules.karlo.kakao.sampler import T2ISampler
@@ -223,7 +228,7 @@ def init(version="Stable Karlo", load_karlo_prior=False):
             from ldm.modules.karlo.kakao.sampler import PriorSampler
             st.info("Loading KARLO CLIP prior...")
             karlo_prior = PriorSampler.from_pretrained(
-                root_dir="/fsx/robin/checkpoints/karlo_models",
+                root_dir="checkpoints/karlo_models",
                 clip_model_path="ViT-L-14.pt",
                 clip_stat_path="ViT-L-14_stats.th",
                 sampling_type="default",
@@ -266,10 +271,10 @@ def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
 
 
 if __name__ == "__main__":
-    st.title("Stable Karlo")
+    st.title("Stable unCLIP")
     mode = "txt2img"
     version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
-    use_karlo = st.checkbox("Use KARLO prior", False)
+    use_karlo = st.checkbox("Use KARLO prior", False) and version in ["Stable unCLIP-L"]
     state = init(version=version, load_karlo_prior=use_karlo)
     st.info(state["msg"])
     prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
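Note (illustration, not part of the patch): why the DPM-Solver wrapper above needs the new dict branch. With `conditioning_key: crossattn-adm`, the conditioning handed to the sampler is a dict rather than a single tensor; the key names `c_crossattn`/`c_adm` are assumed from the repo's `DiffusionWrapper` conventions and are not shown in this diff. The snippet mirrors the added logic and shows the batch doubling used for classifier-free guidance:

```python
# Sketch only: mirror of the dict-aware concatenation added to model_wrapper() in dpm_solver.py.
import torch

B, ctx_len, ctx_dim, adm_dim = 2, 77, 1024, 2048
cond = {"c_crossattn": [torch.randn(B, ctx_len, ctx_dim)], "c_adm": torch.randn(B, adm_dim)}
uncond = {"c_crossattn": [torch.zeros(B, ctx_len, ctx_dim)], "c_adm": torch.zeros(B, adm_dim)}

c_in = dict()
for k in cond:
    if isinstance(cond[k], list):
        c_in[k] = [torch.cat([uncond[k][i], cond[k][i]]) for i in range(len(cond[k]))]
    else:
        c_in[k] = torch.cat([uncond[k], cond[k]])

print(c_in["c_crossattn"][0].shape)  # torch.Size([4, 77, 1024]) -- unconditional + conditional batches
print(c_in["c_adm"].shape)           # torch.Size([4, 2048])
```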