update for openclip release

This commit is contained in:
Robin Rombach 2023-01-27 14:59:02 +01:00
parent 639b3f3f01
commit 5ca06055d4
8 changed files with 195 additions and 20 deletions

View file

@ -137,15 +137,28 @@ Note: The inference config for all model versions is designed to be used with EM
For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
non-EMA to EMA weights. non-EMA to EMA weights.
### Stable Diffusion Meets Karlo ### Stable unCLIP
![upscaling-x4](assets/stable-samples/stable-unclip/panda.jpg) _++++++ NOTE: preliminary checkpoints for internal testing ++++++_
_++++++ NOTE: preliminary checkpoint for internal testing ++++++_
Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125) (also known as DALL·E 2). [unCLIP](https://openai.com/dall-e-2/) is the approach behind OpenAI's [DALL·E 2](https://openai.com/dall-e-2/),
We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior, and Stable Diffusion v2.1. trained to invert CLIP image embeddings.
More precisely, we finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings. We finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings.
This means that the model can be used to produce image variations in the style of unCLIP, but can also be combined with the This means that the model can be used to produce image variations, but can also be combined with a text-to-image
embedding prior of KARLO and directly decodes to 768x768 pixel resolution. embedding prior to yield a full text-to-image model at 768x768 resolution.
We provide two models, trained on OpenAI CLIP-L and OpenCLIP-H image embeddings, respectively, available
_[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview).
To use them, download from Hugging Face, and put and the weights into the `checkpoints` folder.
#### Image Variations
![image-variations-h](assets/stable-samples/stable-unclip/castle.jpg)
![image-variations-h](assets/stable-samples/stable-unclip/cornmen.jpg)
_++TODO: Input images from the DIV2K dataset. Proceed with care++_
#### Stable Diffusion Meets Karlo
![panda](assets/stable-samples/stable-unclip/panda.jpg)
Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125).
We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior, and Stable Diffusion v2.1-768.
To run the model, first download the KARLO checkpoints To run the model, first download the KARLO checkpoints
```shell ```shell
@ -156,7 +169,7 @@ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b623
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
cd ../../ cd ../../
``` ```
and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder` and the finetuned SD2.1 unCLIP-L checkpoint _[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder`
Then, run Then, run

View file

@ -0,0 +1 @@
Put unCLIP checkpoints here.

View file

@ -0,0 +1,80 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
params:
embedding_dropout: 0.25
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 96
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn-adm
scale_factor: 0.18215
monitor: val/loss_simple_ema
embedder_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
noise_aug_config:
target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
params:
timestep_dim: 1024
noise_schedule_config:
timesteps: 1000
beta_schedule: squaredcos_cap_v2
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
num_classes: "sequential"
adm_in_channels: 2048
use_checkpoint: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: "softmax-xformers"
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View file

@ -307,6 +307,15 @@ def model_wrapper(
else: else:
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
t_in = torch.cat([t_continuous] * 2) t_in = torch.cat([t_continuous] * 2)
if isinstance(condition, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_condition, condition]) c_in = torch.cat([unconditional_condition, condition])
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
return noise_uncond + guidance_scale * (noise - noise_uncond) return noise_uncond + guidance_scale * (noise - noise_uncond)

View file

@ -6,7 +6,7 @@ from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
import open_clip import open_clip
from ldm.util import default, count_params from ldm.util import default, count_params, autocast
class AbstractEncoder(nn.Module): class AbstractEncoder(nn.Module):
@ -232,6 +232,63 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
return self(text) return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
pretrained=version, )
del model.transformer
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "penultimate":
raise NotImplementedError()
self.layer_idx = 1
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic', align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
@autocast
def forward(self, image, no_dropout=False):
z = self.encode_with_vision_transformer(image)
if self.ucg_rate > 0. and not no_dropout:
z = torch.bernoulli((1.-self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
return z
def encode_with_vision_transformer(self, img):
img = self.preprocess(img)
x = self.model.visual(img)
return x
def encode(self, text):
return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder): class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
clip_max_length=77, t5_max_length=77): clip_max_length=77, t5_max_length=77):

View file

@ -8,6 +8,16 @@ from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
def autocast(f):
def do_autocast(*args, **kwargs):
with torch.cuda.amp.autocast(enabled=True,
dtype=torch.get_autocast_gpu_dtype(),
cache_enabled=torch.is_autocast_cache_enabled()):
return f(*args, **kwargs)
return do_autocast
def log_txt_as_img(wh, xc, size=10): def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height) # wh a tuple of (width, height)
# xc a list of captions to plot # xc a list of captions to plot

View file

@ -22,10 +22,11 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
PROMPTS_ROOT = "scripts/prompts/" PROMPTS_ROOT = "scripts/prompts/"
SAVE_PATH = "outputs/demo/stable-karlo/" SAVE_PATH = "outputs/demo/stable-unclip/"
VERSION2SPECS = { VERSION2SPECS = {
"Stable Karlo": {"H": 768, "W": 768, "C": 4, "f": 8}, "Stable unCLIP-L": {"H": 768, "W": 768, "C": 4, "f": 8},
"Stable unOpenCLIP-H": {"H": 768, "W": 768, "C": 4, "f": 8},
"Full Karlo": {} "Full Karlo": {}
} }
@ -193,12 +194,16 @@ def torch2np(x):
@st.cache(allow_output_mutation=True, suppress_st_warning=True) @st.cache(allow_output_mutation=True, suppress_st_warning=True)
def init(version="Stable Karlo", load_karlo_prior=False): def init(version="Stable unCLIP-L", load_karlo_prior=False):
state = dict() state = dict()
if not "model" in state: if not "model" in state:
if version == "Stable Karlo": if version == "Stable unCLIP-L":
config = "configs/stable-diffusion/v2-1-stable-karlo-inference.yaml" config = "configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml"
ckpt = "checkpoints/v2-1-stable-unclip-ft.ckpt" ckpt = "checkpoints/v2-1-stable-unclip-l-ft.ckpt"
elif version == "Stable unOpenCLIP-H":
config = "configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml"
ckpt = "checkpoints/v2-1-stable-unclip-h-ft.ckpt"
elif version == "Full Karlo": elif version == "Full Karlo":
from ldm.modules.karlo.kakao.sampler import T2ISampler from ldm.modules.karlo.kakao.sampler import T2ISampler
@ -223,7 +228,7 @@ def init(version="Stable Karlo", load_karlo_prior=False):
from ldm.modules.karlo.kakao.sampler import PriorSampler from ldm.modules.karlo.kakao.sampler import PriorSampler
st.info("Loading KARLO CLIP prior...") st.info("Loading KARLO CLIP prior...")
karlo_prior = PriorSampler.from_pretrained( karlo_prior = PriorSampler.from_pretrained(
root_dir="/fsx/robin/checkpoints/karlo_models", root_dir="checkpoints/karlo_models",
clip_model_path="ViT-L-14.pt", clip_model_path="ViT-L-14.pt",
clip_stat_path="ViT-L-14_stats.th", clip_stat_path="ViT-L-14_stats.th",
sampling_type="default", sampling_type="default",
@ -266,10 +271,10 @@ def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
if __name__ == "__main__": if __name__ == "__main__":
st.title("Stable Karlo") st.title("Stable unCLIP")
mode = "txt2img" mode = "txt2img"
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0) version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
use_karlo = st.checkbox("Use KARLO prior", False) use_karlo = st.checkbox("Use KARLO prior", False) and version in ["Stable unCLIP-L"]
state = init(version=version, load_karlo_prior=use_karlo) state = init(version=version, load_karlo_prior=use_karlo)
st.info(state["msg"]) st.info(state["msg"])
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse") prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")