make it work

Robin Rombach 2023-01-29 23:21:30 +01:00
parent 5ca06055d4
commit cddd65d51f
6 changed files with 34 additions and 18 deletions

View file

@@ -149,12 +149,33 @@ We provide two models, trained on OpenAI CLIP-L and OpenCLIP-H image embeddings,
_[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview).
To use them, download from Hugging Face, and put the weights into the `checkpoints` folder.
#### Image Variations
-![image-variations-h](assets/stable-samples/stable-unclip/castle.jpg)
-![image-variations-h](assets/stable-samples/stable-unclip/cornmen.jpg)
-_++TODO: Input images from the DIV2K dataset. Proceed with care++_
-#### Stable Diffusion Meets Karlo
+![image-variations-l-1](assets/stable-samples/stable-unclip/houses_out.jpeg)
+![image-variations-l-2](assets/stable-samples/stable-unclip/plates_out.jpeg)
+_++TODO: Input images from the DIV2K dataset. check license++_
+Run
+```
+streamlit run scripts/streamlit/stableunclip.py
+```
+to launch a streamlit script that can be used to make image variations with both models (CLIP-L and OpenCLIP-H).
+These models can process a `noise_level`, which specifies an amount of Gaussian noise added to the CLIP embeddings.
+This can be used to increase output variance, as in the following examples.
+**noise_level = 0**
+![image-variations-l-3](assets/stable-samples/stable-unclip/oldcar000.jpeg)
+**noise_level = 500**
+![image-variations-l-4](assets/stable-samples/stable-unclip/oldcar500.jpeg)
+**noise_level = 800**
+![image-variations-l-6](assets/stable-samples/stable-unclip/oldcar800.jpeg)
+### Stable Diffusion Meets Karlo
![panda](assets/stable-samples/stable-unclip/panda.jpg)
Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125).
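The `noise_level` documented in the hunk above is, in effect, a step index into a DDPM-style forward process applied to the CLIP image embedding. A minimal sketch of that idea (the function name, the linear beta schedule, and the 1000-level range are assumptions for illustration, not the repo's exact augmentor):

```python
import torch

def noise_clip_embedding(emb: torch.Tensor, noise_level: int,
                         max_noise_level: int = 1000) -> torch.Tensor:
    """Add DDPM-style Gaussian noise to a CLIP embedding (illustrative sketch)."""
    betas = torch.linspace(1e-4, 2e-2, max_noise_level)  # assumed schedule
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    a = alphas_cumprod[noise_level]
    # Higher noise_level => smaller a => less of the original embedding survives.
    return a.sqrt() * emb + (1.0 - a).sqrt() * torch.randn_like(emb)

emb = torch.randn(1, 1024)               # stand-in for an image embedding
subtle = noise_clip_embedding(emb, 0)    # nearly the original embedding
varied = noise_clip_embedding(emb, 800)  # mostly noise -> more output variance
```

This matches the samples above: at level 0 the variations stay close to the input, while at 800 most of the embedding signal has been replaced by noise.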
@@ -174,10 +195,10 @@ and the finetuned SD2.1 unCLIP-L checkpoint _[TODO: +++prelim private upload on
Then, run
```
-streamlit run scripts/streamlit/stablekarlo.py
+streamlit run scripts/streamlit/stableunclip.py
```
+and pick the `use_karlo` option in the GUI.
-The script optionally supports sampling from the full Karlo model. To do so, you need to download the 64x64 decoder and 64->256 upscaler
+The script optionally supports sampling from the full Karlo model. To use it, download the 64x64 decoder and 64->256 upscaler
via
```shell
cd checkpoints/karlo_models
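Since the weights live on Hugging Face, fetching them can also be scripted. A sketch using `huggingface_hub`; the checkpoint filenames are hypothetical placeholders (the repo is still a private preview, so this also assumes you are authenticated and should check the model card for the real names):

```python
import shutil
from pathlib import Path
from huggingface_hub import hf_hub_download

checkpoints = Path("checkpoints")
checkpoints.mkdir(exist_ok=True)

# Hypothetical filenames -- replace with the actual ones listed on
# https://huggingface.co/stabilityai/stable-unclip-preview
for filename in ["sd21-unclip-l.ckpt", "sd21-unclip-h.ckpt"]:
    cached = hf_hub_download(repo_id="stabilityai/stable-unclip-preview",
                             filename=filename)
    shutil.copy(cached, checkpoints / filename)
```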

View file

@@ -46,7 +46,6 @@ model:
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
-        spatial_transformer_attn_type: "softmax-xformers"
        legacy: False
    first_stage_config:

View file

@@ -50,7 +50,6 @@ model:
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
-        spatial_transformer_attn_type: "softmax-xformers"
        legacy: False
    first_stage_config:
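Both config hunks drop the same key: `spatial_transformer_attn_type: "softmax-xformers"` pinned the memory-efficient xformers attention backend, so these configs failed on machines without xformers installed; with the key removed, the spatial transformer falls back to its default. A sketch of the kind of name-based dispatch this key feeds (the mapping mirrors `ldm/modules/attention.py`, but treat the exact class names as assumptions):

```python
# Maps the config string to an attention implementation.
ATTENTION_MODES = {
    "softmax": "CrossAttention",                          # plain PyTorch attention
    "softmax-xformers": "MemoryEfficientCrossAttention",  # requires xformers
}

def resolve_attn_type(attn_type: str = "softmax") -> str:
    if attn_type not in ATTENTION_MODES:
        raise ValueError(f"unknown attention mode: {attn_type}")
    return ATTENTION_MODES[attn_type]

print(resolve_attn_type())                    # default used after this commit
print(resolve_attn_type("softmax-xformers"))  # what the configs pinned before
```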

View file

@@ -225,6 +225,7 @@ class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)
+
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
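The hunk cuts off at the docstring, so for context: in the guided-diffusion lineage this file comes from, `conv_nd` is a small factory that dispatches on `dims` (a sketch from that lineage, not copied from this diff):

```python
import torch.nn as nn

def conv_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D convolution module."""
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

conv = conv_nd(2, 320, 320, 3, padding=1)  # behaves exactly like nn.Conv2d
```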

View file

@@ -132,7 +132,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
        return self(text)

-from clip import load as load_clip

class ClipImageEmbedder(nn.Module):
    def __init__(
            self,
@@ -143,6 +142,7 @@ class ClipImageEmbedder(nn.Module):
            ucg_rate=0.
    ):
        super().__init__()
+        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device, jit=jit)
        self.antialias = antialias
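This pair of hunks is the classic deferred-import pattern: moving `from clip import load as load_clip` from module scope into `__init__` means importing the encoders module no longer requires the optional `clip` package; only constructing `ClipImageEmbedder` does. A self-contained sketch with a hypothetical stand-in class:

```python
import torch.nn as nn

class OptionalClipEmbedder(nn.Module):  # hypothetical stand-in
    def __init__(self, model: str = "ViT-L/14", device: str = "cpu"):
        super().__init__()
        # ImportError is raised here, at construction time,
        # rather than when this module is first imported.
        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device)
```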

View file

@@ -298,15 +298,11 @@ if __name__ == "__main__":
    seed_everything(seed)

    ucg_schedule = None
-    sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 0)
+    sampler = st.sidebar.selectbox("Sampler", ["DDIM", "DPM"], 0)
    if version == "Full Karlo":
        pass
    else:
-        if sampler == "PLMS":
-            st.warning("NOTE: Some models (such as v-pred) currently only support DDIM/DPM sampling here")
-            sampler = PLMSSampler(state["model"])
-        elif sampler == "DPM":
-            st.warning("NOTE: Using DPM sampler with default sampling parameters (DPM-2)")
+        if sampler == "DPM":
            sampler = DPMSolverSampler(state["model"])
        elif sampler == "DDIM":
            sampler = DDIMSampler(state["model"])
@@ -342,7 +338,6 @@ if __name__ == "__main__":
        init_img = get_init_img(batch_size=number_cols)
        with torch.no_grad():
            adm_cond = state["model"].embedder(init_img)
-            adm_uc = torch.zeros_like(adm_cond)
            if state["model"].noise_augmentor is not None:
                noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
                                              max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
@@ -350,6 +345,7 @@ if __name__ == "__main__":
                    torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
                # assume this gives embeddings of noise levels
                adm_cond = torch.cat((c_adm, noise_level_emb), 1)
+            adm_uc = torch.zeros_like(adm_cond)

        if st.button("Sample"):
            print("running prompt:", prompt)