Mirror of https://github.com/Stability-AI/stablediffusion.git

Commit 5ca06055d4 (parent 639b3f3f01): update for openclip release
8 changed files with 195 additions and 20 deletions
README.md (31 lines changed)

@@ -137,15 +137,28 @@ Note: The inference config for all model versions is designed to be used with EM
 For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
 non-EMA to EMA weights.

-### Stable Diffusion Meets Karlo
-![upscaling-x4](assets/stable-samples/stable-unclip/panda.jpg)
-_++++++ NOTE: preliminary checkpoint for internal testing ++++++_
+### Stable unCLIP
+_++++++ NOTE: preliminary checkpoints for internal testing ++++++_

-Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125) (also known as DALL·E 2).
-We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior, and Stable Diffusion v2.1.
-More precisely, we finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings.
-This means that the model can be used to produce image variations in the style of unCLIP, but can also be combined with the
-embedding prior of KARLO and directly decodes to 768x768 pixel resolution.
+[unCLIP](https://openai.com/dall-e-2/) is the approach behind OpenAI's [DALL·E 2](https://openai.com/dall-e-2/),
+trained to invert CLIP image embeddings.
+We finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings.
+This means that the model can be used to produce image variations, but can also be combined with a text-to-image
+embedding prior to yield a full text-to-image model at 768x768 resolution.
+
+We provide two models, trained on OpenAI CLIP-L and OpenCLIP-H image embeddings, respectively, available
+_[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview).
+To use them, download the weights from Hugging Face and put them into the `checkpoints` folder.
+
+#### Image Variations
+![image-variations-h](assets/stable-samples/stable-unclip/castle.jpg)
+![image-variations-h](assets/stable-samples/stable-unclip/cornmen.jpg)
+
+_++TODO: Input images from the DIV2K dataset. Proceed with care++_
+
+#### Stable Diffusion Meets Karlo
+![panda](assets/stable-samples/stable-unclip/panda.jpg)
+
+Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125).
+We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior, and Stable Diffusion v2.1-768.

 To run the model, first download the KARLO checkpoints
 ```shell

@@ -156,7 +169,7 @@ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b623
 wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
 cd ../../
 ```
-and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder`
+and the finetuned SD2.1 unCLIP-L checkpoint _[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints` folder.

 Then, run
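The README text above says the finetuned decoder accepts "a CLIP ViT-L/14 image embedding in addition to the text encodings". As a hypothetical illustration of what such an embedding is, here is plain `open_clip` usage (not code from this commit); the model tag and image path are placeholders:

```python
# Illustrative sketch only: compute the kind of CLIP image embedding the
# finetuned decoder is conditioned on. Model/pretrained tags and the image
# path are assumptions, not values prescribed by this commit.
import torch
import open_clip
from PIL import Image

# "ViT-H-14"/"laion2b_s32b_b79k" matches the OpenCLIP-H variant;
# "ViT-L-14"/"openai" would correspond to the CLIP-L variant.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-H-14", pretrained="laion2b_s32b_b79k"
)
model.eval()

image = preprocess(Image.open("example.jpg")).unsqueeze(0)  # (1, 3, 224, 224)
with torch.no_grad():
    emb = model.encode_image(image)  # (1, 1024) pooled image embedding for ViT-H-14

print(emb.shape)
```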
checkpoints/checkpoints.txt (new file, 1 line)

@@ -0,0 +1 @@
+Put unCLIP checkpoints here.
configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml (new file, 80 lines)

```yaml
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
  params:
    embedding_dropout: 0.25
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 96
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn-adm
    scale_factor: 0.18215
    monitor: val/loss_simple_ema

    embedder_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder

    noise_aug_config:
      target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
      params:
        timestep_dim: 1024
        noise_schedule_config:
          timesteps: 1000
          beta_schedule: squaredcos_cap_v2

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: "sequential"
        adm_in_channels: 2048
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: "softmax-xformers"
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
```
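As a rough sketch of how a config like this is typically consumed in this codebase (not part of the commit): it assumes the repository is importable, that the `ImageEmbeddingConditionedLatentDiffusion` class the config points to is present, and that the finetuned checkpoint exists locally; the `.ckpt` filename is taken from the demo-script changes further down and may differ in the final release.

```python
# Sketch only: build the unCLIP-H model from the new config and load weights.
# Instantiating the config also downloads the frozen (Open)CLIP encoders.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml")
model = instantiate_from_config(config.model)

ckpt = torch.load("checkpoints/v2-1-stable-unclip-h-ft.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # Lightning checkpoints nest weights under "state_dict"
missing, unexpected = model.load_state_dict(state_dict, strict=False)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.eval().to(device)
```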
@@ -307,6 +307,15 @@ def model_wrapper(
             else:
                 x_in = torch.cat([x] * 2)
                 t_in = torch.cat([t_continuous] * 2)
-                c_in = torch.cat([unconditional_condition, condition])
+                if isinstance(condition, dict):
+                    assert isinstance(unconditional_condition, dict)
+                    c_in = dict()
+                    for k in condition:
+                        if isinstance(condition[k], list):
+                            c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
+                        else:
+                            c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
+                else:
+                    c_in = torch.cat([unconditional_condition, condition])
                 noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                 return noise_uncond + guidance_scale * (noise - noise_uncond)
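The added branch batches the unconditional and conditional inputs for classifier-free guidance even when the conditioning is a dict, as with the `crossattn-adm` conditioning used by the new config. A small self-contained illustration of that merge with dummy tensors (key names follow the repo's `c_crossattn`/`c_adm` convention, shapes are arbitrary placeholders):

```python
# Toy illustration of the dict-conditioning merge for classifier-free guidance.
import torch

cond = {"c_crossattn": [torch.randn(2, 77, 1024)], "c_adm": torch.randn(2, 2048)}
uncond = {"c_crossattn": [torch.zeros(2, 77, 1024)], "c_adm": torch.zeros(2, 2048)}

c_in = {}
for k in cond:
    if isinstance(cond[k], list):
        # list-valued entries are concatenated element-wise along the batch dim
        c_in[k] = [torch.cat([uncond[k][i], cond[k][i]]) for i in range(len(cond[k]))]
    else:
        c_in[k] = torch.cat([uncond[k], cond[k]])

print(c_in["c_crossattn"][0].shape)  # torch.Size([4, 77, 1024]): uncond and cond stacked along batch
print(c_in["c_adm"].shape)           # torch.Size([4, 2048])
```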
@@ -6,7 +6,7 @@ from torch.utils.checkpoint import checkpoint
 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

 import open_clip
-from ldm.util import default, count_params
+from ldm.util import default, count_params, autocast


 class AbstractEncoder(nn.Module):

@@ -232,6 +232,63 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         return self(text)


+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+    """
+    Uses the OpenCLIP vision transformer encoder for images
+    """
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
+        super().__init__()
+        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
+                                                            pretrained=version, )
+        del model.transformer
+        self.model = model
+
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        if self.layer == "penultimate":
+            raise NotImplementedError()
+            self.layer_idx = 1
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+        self.ucg_rate = ucg_rate
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        x = kornia.geometry.resize(x, (224, 224),
+                                   interpolation='bicubic', align_corners=True,
+                                   antialias=self.antialias)
+        x = (x + 1.) / 2.
+        # renormalize according to clip
+        x = kornia.enhance.normalize(x, self.mean, self.std)
+        return x
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    @autocast
+    def forward(self, image, no_dropout=False):
+        z = self.encode_with_vision_transformer(image)
+        if self.ucg_rate > 0. and not no_dropout:
+            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
+        return z
+
+    def encode_with_vision_transformer(self, img):
+        img = self.preprocess(img)
+        x = self.model.visual(img)
+        return x
+
+    def encode(self, text):
+        return self(text)
+
+
 class FrozenCLIPT5Encoder(AbstractEncoder):
     def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                  clip_max_length=77, t5_max_length=77):
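A hedged usage sketch for the new embedder, not taken from the repo: it assumes the repository and `kornia` are importable, that the OpenCLIP ViT-H-14 weights can be downloaded, and that inputs are image tensors scaled to [-1, 1], which is what `preprocess` maps back to CLIP's normalization.

```python
# Illustrative only: exercise FrozenOpenCLIPImageEmbedder on a dummy image batch.
# Downloading the laion2b_s32b_b79k weights is several GB; CPU execution may
# emit an autocast warning because the @autocast wrapper targets CUDA.
import torch
from ldm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder

embedder = FrozenOpenCLIPImageEmbedder(device="cpu", ucg_rate=0.1)

# preprocess() expects inputs in [-1, 1]; it resizes to 224x224 and renormalizes.
images = torch.rand(4, 3, 256, 256) * 2.0 - 1.0
with torch.no_grad():
    z = embedder(images)                          # (4, 1024) pooled embeddings for ViT-H-14
    z_clean = embedder(images, no_dropout=True)   # skip the ucg_rate embedding dropout

# With ucg_rate > 0, whole rows of z may be zeroed out (unconditional guidance training).
print(z.shape, (z == 0).all(dim=1))
```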
ldm/util.py (10 lines changed)

@@ -8,6 +8,16 @@ from inspect import isfunction
 from PIL import Image, ImageDraw, ImageFont


+def autocast(f):
+    def do_autocast(*args, **kwargs):
+        with torch.cuda.amp.autocast(enabled=True,
+                                     dtype=torch.get_autocast_gpu_dtype(),
+                                     cache_enabled=torch.is_autocast_cache_enabled()):
+            return f(*args, **kwargs)
+
+    return do_autocast
+
+
 def log_txt_as_img(wh, xc, size=10):
     # wh a tuple of (width, height)
     # xc a list of captions to plot
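The new `autocast` decorator simply runs the wrapped function under CUDA automatic mixed precision. A minimal, hypothetical usage sketch; the decorated function is made up for illustration and is not repo code:

```python
# Hypothetical example of the decorator added above; `scaled_matmul` is illustrative.
import torch
from ldm.util import autocast

@autocast
def scaled_matmul(a, b, scale=1.0):
    # Under torch.cuda.amp.autocast, CUDA matmuls may run in fp16/bf16.
    return scale * (a @ b)

if torch.cuda.is_available():
    a = torch.randn(8, 16, device="cuda")
    b = torch.randn(16, 4, device="cuda")
    print(scaled_matmul(a, b).dtype)  # typically torch.float16 (or bfloat16) under autocast
else:
    # On CPU the wrapper has no effect on these ops and the result stays float32.
    print(scaled_matmul(torch.randn(8, 16), torch.randn(16, 4)).dtype)
```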
@@ -22,10 +22,11 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
 torch.set_grad_enabled(False)

 PROMPTS_ROOT = "scripts/prompts/"
-SAVE_PATH = "outputs/demo/stable-karlo/"
+SAVE_PATH = "outputs/demo/stable-unclip/"

 VERSION2SPECS = {
-    "Stable Karlo": {"H": 768, "W": 768, "C": 4, "f": 8},
+    "Stable unCLIP-L": {"H": 768, "W": 768, "C": 4, "f": 8},
+    "Stable unOpenCLIP-H": {"H": 768, "W": 768, "C": 4, "f": 8},
     "Full Karlo": {}
 }

@@ -193,12 +194,16 @@ def torch2np(x):


 @st.cache(allow_output_mutation=True, suppress_st_warning=True)
-def init(version="Stable Karlo", load_karlo_prior=False):
+def init(version="Stable unCLIP-L", load_karlo_prior=False):
     state = dict()
     if not "model" in state:
-        if version == "Stable Karlo":
-            config = "configs/stable-diffusion/v2-1-stable-karlo-inference.yaml"
-            ckpt = "checkpoints/v2-1-stable-unclip-ft.ckpt"
+        if version == "Stable unCLIP-L":
+            config = "configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml"
+            ckpt = "checkpoints/v2-1-stable-unclip-l-ft.ckpt"
+
+        elif version == "Stable unOpenCLIP-H":
+            config = "configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml"
+            ckpt = "checkpoints/v2-1-stable-unclip-h-ft.ckpt"

         elif version == "Full Karlo":
             from ldm.modules.karlo.kakao.sampler import T2ISampler

@@ -223,7 +228,7 @@ def init(version="Stable Karlo", load_karlo_prior=False):
         from ldm.modules.karlo.kakao.sampler import PriorSampler
         st.info("Loading KARLO CLIP prior...")
         karlo_prior = PriorSampler.from_pretrained(
-            root_dir="/fsx/robin/checkpoints/karlo_models",
+            root_dir="checkpoints/karlo_models",
             clip_model_path="ViT-L-14.pt",
             clip_stat_path="ViT-L-14_stats.th",
             sampling_type="default",

@@ -266,10 +271,10 @@ def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):


 if __name__ == "__main__":
-    st.title("Stable Karlo")
+    st.title("Stable unCLIP")
     mode = "txt2img"
     version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
-    use_karlo = st.checkbox("Use KARLO prior", False)
+    use_karlo = st.checkbox("Use KARLO prior", False) and version in ["Stable unCLIP-L"]
     state = init(version=version, load_karlo_prior=use_karlo)
     st.info(state["msg"])
     prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")