mirror of https://github.com/Stability-AI/stablediffusion.git

update for openclip release

parent 639b3f3f01
commit 5ca06055d4

8 changed files with 195 additions and 20 deletions

README.md: 31 changes
@@ -137,15 +137,28 @@ Note: The inference config for all model versions is designed to be used with EM
For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
non-EMA to EMA weights.

### Stable Diffusion Meets Karlo
![upscaling-x4](assets/stable-samples/stable-unclip/panda.jpg)
_++++++ NOTE: preliminary checkpoint for internal testing ++++++_

### Stable unCLIP
_++++++ NOTE: preliminary checkpoints for internal testing ++++++_

Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125) (also known as DALL·E 2).
We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior and Stable Diffusion v2.1.
More precisely, we finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings.
This means that the model can be used to produce image variations in the style of unCLIP, but can also be combined with the
embedding prior of KARLO and directly decode to 768x768 pixel resolution.

[unCLIP](https://openai.com/dall-e-2/) is the approach behind OpenAI's [DALL·E 2](https://openai.com/dall-e-2/),
trained to invert CLIP image embeddings.
We finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings.
This means that the model can be used to produce image variations, but can also be combined with a text-to-image
embedding prior to yield a full text-to-image model at 768x768 resolution.
We provide two models, trained on OpenAI CLIP-L and OpenCLIP-H image embeddings, respectively, available
_[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview).
To use them, download the weights from Hugging Face and put them into the `checkpoints` folder.
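
For orientation, here is a minimal sketch of the first half of that pipeline for the OpenCLIP-H variant, using the `FrozenOpenCLIPImageEmbedder` added in this commit (weights are fetched through `open_clip`; the actual sampling entry point is the Streamlit demo touched later in this diff):

```python
import torch
from ldm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder

# Sketch only: compute the pooled OpenCLIP-H image embedding that the finetuned
# SD 2.1 decoder consumes in addition to the usual text conditioning.
embedder = FrozenOpenCLIPImageEmbedder(arch="ViT-H-14", version="laion2b_s32b_b79k")
source = torch.rand(1, 3, 768, 768) * 2.0 - 1.0   # stand-in image, scaled to [-1, 1]
with torch.no_grad():
    image_emb = embedder(source)                  # roughly (1, 1024) for ViT-H-14
# For image variations this embedding comes from a real input image; for full
# text-to-image it would instead come from an embedding prior such as Karlo.
```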

#### Image Variations
![image-variations-h](assets/stable-samples/stable-unclip/castle.jpg)
![image-variations-h](assets/stable-samples/stable-unclip/cornmen.jpg)

_++TODO: Input images from the DIV2K dataset. Proceed with care++_

#### Stable Diffusion Meets Karlo
![panda](assets/stable-samples/stable-unclip/panda.jpg)

Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125).
We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior and Stable Diffusion v2.1-768.

To run the model, first download the KARLO checkpoints
```shell

@@ -156,7 +169,7 @@ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b623
wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
cd ../../
```
and the finetuned SD 2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints` folder
and the finetuned SD 2.1 unCLIP-L checkpoint _[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints` folder

Then, run

checkpoints/checkpoints.txt: 1 addition (new file)
@@ -0,0 +1 @@
Put unCLIP checkpoints here.

configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml: 80 additions (new file)
@@ -0,0 +1,80 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
  params:
    embedding_dropout: 0.25
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 96
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn-adm
    scale_factor: 0.18215
    monitor: val/loss_simple_ema

    embedder_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder

    noise_aug_config:
      target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
      params:
        timestep_dim: 1024
        noise_schedule_config:
          timesteps: 1000
          beta_schedule: squaredcos_cap_v2

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: "sequential"
        adm_in_channels: 2048
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: "softmax-xformers"
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
@@ -307,7 +307,16 @@ def model_wrapper(
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                if isinstance(condition, dict):
                    assert isinstance(unconditional_condition, dict)
                    c_in = dict()
                    for k in condition:
                        if isinstance(condition[k], list):
                            c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
                        else:
                            c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
                else:
                    c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)
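
For context, the last two lines of this hunk are plain classifier-free guidance, and the new branch extends the uncond/cond batching to dict-valued conditionings. A self-contained illustration (shapes are made up; the keys mirror the `crossattn-adm` conditioning used by the new config):

```python
import torch

def cfg_combine(noise_uncond, noise_cond, guidance_scale):
    # classifier-free guidance: push the prediction away from the unconditional branch
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

# dict-style conditioning: concatenate uncond-first along the batch dimension, per key
cond   = {"c_crossattn": [torch.randn(2, 77, 1024)], "c_adm": torch.randn(2, 2048)}
uncond = {"c_crossattn": [torch.zeros(2, 77, 1024)], "c_adm": torch.zeros(2, 2048)}
c_in = {k: [torch.cat([uncond[k][i], cond[k][i]]) for i in range(len(cond[k]))]
        if isinstance(cond[k], list) else torch.cat([uncond[k], cond[k]])
        for k in cond}
print({k: (v[0].shape if isinstance(v, list) else v.shape) for k, v in c_in.items()})
```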
@@ -6,7 +6,7 @@ from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

import open_clip
from ldm.util import default, count_params
from ldm.util import default, count_params, autocast


class AbstractEncoder(nn.Module):
@@ -232,6 +232,63 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
        return self(text)


class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """
    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                            pretrained=version, )
        del model.transformer
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1

        self.antialias = antialias

        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        z = self.encode_with_vision_transformer(image)
        if self.ucg_rate > 0. and not no_dropout:
            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
        return z

    def encode_with_vision_transformer(self, img):
        img = self.preprocess(img)
        x = self.model.visual(img)
        return x

    def encode(self, text):
        return self(text)
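
The `ucg_rate` branch in `forward` is the unconditional-guidance dropout used during finetuning: with probability `ucg_rate`, a sample's entire image embedding is zeroed so the model also learns the unconditional case. A tiny standalone illustration of that one line (numbers are arbitrary):

```python
import torch

torch.manual_seed(0)
ucg_rate = 0.25
z = torch.ones(8, 1024)  # pretend batch of pooled CLIP image embeddings

# keep each row with probability (1 - ucg_rate), zero it out otherwise
keep = torch.bernoulli((1. - ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None]
z_dropped = keep * z

print(keep.squeeze(1))                         # per-sample 0/1 mask, e.g. tensor([1., 1., 0., ...])
print(int((z_dropped.sum(dim=1) == 0).sum()))  # how many embeddings were dropped
```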


class FrozenCLIPT5Encoder(AbstractEncoder):
    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                 clip_max_length=77, t5_max_length=77):

ldm/util.py: 10 changes
@@ -8,6 +8,16 @@ from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont


def autocast(f):
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(enabled=True,
                                     dtype=torch.get_autocast_gpu_dtype(),
                                     cache_enabled=torch.is_autocast_cache_enabled()):
            return f(*args, **kwargs)

    return do_autocast


def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
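
A minimal usage sketch of this decorator (the wrapped function below is made up): whatever it decorates runs under the globally configured GPU autocast dtype, which is how the new `FrozenOpenCLIPImageEmbedder.forward` uses it.

```python
import torch
from ldm.util import autocast

@autocast
def cosine_scores(a, b):
    # executes inside torch.cuda.amp.autocast with the current autocast dtype
    return torch.nn.functional.cosine_similarity(a, b, dim=-1)

device = "cuda" if torch.cuda.is_available() else "cpu"
x, y = torch.randn(4, 1024, device=device), torch.randn(4, 1024, device=device)
print(cosine_scores(x, y).shape)  # torch.Size([4])
```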
@@ -22,10 +22,11 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
torch.set_grad_enabled(False)

PROMPTS_ROOT = "scripts/prompts/"
SAVE_PATH = "outputs/demo/stable-karlo/"
SAVE_PATH = "outputs/demo/stable-unclip/"

VERSION2SPECS = {
    "Stable Karlo": {"H": 768, "W": 768, "C": 4, "f": 8},
    "Stable unCLIP-L": {"H": 768, "W": 768, "C": 4, "f": 8},
    "Stable unOpenCLIP-H": {"H": 768, "W": 768, "C": 4, "f": 8},
    "Full Karlo": {}
}
@@ -193,12 +194,16 @@ def torch2np(x):


@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def init(version="Stable Karlo", load_karlo_prior=False):
def init(version="Stable unCLIP-L", load_karlo_prior=False):
    state = dict()
    if not "model" in state:
        if version == "Stable Karlo":
            config = "configs/stable-diffusion/v2-1-stable-karlo-inference.yaml"
            ckpt = "checkpoints/v2-1-stable-unclip-ft.ckpt"
        if version == "Stable unCLIP-L":
            config = "configs/stable-diffusion/v2-1-stable-unclip-l-inference.yaml"
            ckpt = "checkpoints/v2-1-stable-unclip-l-ft.ckpt"

        elif version == "Stable unOpenCLIP-H":
            config = "configs/stable-diffusion/v2-1-stable-unclip-h-inference.yaml"
            ckpt = "checkpoints/v2-1-stable-unclip-h-ft.ckpt"

        elif version == "Full Karlo":
            from ldm.modules.karlo.kakao.sampler import T2ISampler
@@ -223,7 +228,7 @@ def init(version="Stable Karlo", load_karlo_prior=False):
            from ldm.modules.karlo.kakao.sampler import PriorSampler
            st.info("Loading KARLO CLIP prior...")
            karlo_prior = PriorSampler.from_pretrained(
                root_dir="/fsx/robin/checkpoints/karlo_models",
                root_dir="checkpoints/karlo_models",
                clip_model_path="ViT-L-14.pt",
                clip_stat_path="ViT-L-14_stats.th",
                sampling_type="default",
@@ -266,10 +271,10 @@ def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):


if __name__ == "__main__":
    st.title("Stable Karlo")
    st.title("Stable unCLIP")
    mode = "txt2img"
    version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
    use_karlo = st.checkbox("Use KARLO prior", False)
    use_karlo = st.checkbox("Use KARLO prior", False) and version in ["Stable unCLIP-L"]
    state = init(version=version, load_karlo_prior=use_karlo)
    st.info(state["msg"])
    prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")