mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-21 23:24:59 +00:00
make it work
This commit is contained in:
parent
5ca06055d4
commit
cddd65d51f
6 changed files with 34 additions and 18 deletions
35
README.md
35
README.md
|
@ -149,12 +149,33 @@ We provide two models, trained on OpenAI CLIP-L and OpenCLIP-H image embeddings,
|
|||
_[TODO: +++prelim private upload on HF+++]_ from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview).
|
||||
To use them, download from Hugging Face, and put and the weights into the `checkpoints` folder.
|
||||
#### Image Variations
|
||||
![image-variations-h](assets/stable-samples/stable-unclip/castle.jpg)
|
||||
![image-variations-h](assets/stable-samples/stable-unclip/cornmen.jpg)
|
||||
![image-variations-l-1](assets/stable-samples/stable-unclip/houses_out.jpeg)
|
||||
![image-variations-l-2](assets/stable-samples/stable-unclip/plates_out.jpeg)
|
||||
|
||||
_++TODO: Input images from the DIV2K dataset. Proceed with care++_
|
||||
_++TODO: Input images from the DIV2K dataset. check license++_
|
||||
|
||||
#### Stable Diffusion Meets Karlo
|
||||
Run
|
||||
|
||||
```
|
||||
streamlit run scripts/streamlit/stableunclip.py
|
||||
```
|
||||
to launch a streamlit script than can be used to make image variations with both models (CLIP-L and OpenCLIP-H).
|
||||
These models can process a `noise_level`, which specifies an amount of Gaussian noise added to the CLIP embeddings.
|
||||
This can be used to increase output variance as in the following examples.
|
||||
|
||||
**noise_level = 0**
|
||||
![image-variations-l-3](assets/stable-samples/stable-unclip/oldcar000.jpeg)
|
||||
|
||||
**noise_level = 500**
|
||||
![image-variations-l-4](assets/stable-samples/stable-unclip/oldcar500.jpeg)
|
||||
|
||||
**noise_level = 800**
|
||||
![image-variations-l-6](assets/stable-samples/stable-unclip/oldcar800.jpeg)
|
||||
|
||||
|
||||
|
||||
|
||||
### Stable Diffusion Meets Karlo
|
||||
![panda](assets/stable-samples/stable-unclip/panda.jpg)
|
||||
|
||||
Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125).
|
||||
|
@ -174,10 +195,10 @@ and the finetuned SD2.1 unCLIP-L checkpoint _[TODO: +++prelim private upload on
|
|||
Then, run
|
||||
|
||||
```
|
||||
streamlit run scripts/streamlit/stablekarlo.py
|
||||
streamlit run scripts/streamlit/stableunclip.py
|
||||
```
|
||||
|
||||
The script optionally supports sampling from the full Karlo model. To do so, you need to download the 64x64 decoder and 64->256 upscaler
|
||||
and pick the `use_karlo` option in the GUI.
|
||||
The script optionally supports sampling from the full Karlo model. To use it, download the 64x64 decoder and 64->256 upscaler
|
||||
via
|
||||
```shell
|
||||
cd checkpoints/karlo_models
|
||||
|
|
|
@ -46,7 +46,6 @@ model:
|
|||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
spatial_transformer_attn_type: "softmax-xformers"
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
|
|
|
@ -50,7 +50,6 @@ model:
|
|||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
spatial_transformer_attn_type: "softmax-xformers"
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
|
|
|
@ -225,6 +225,7 @@ class GroupNorm32(nn.GroupNorm):
|
|||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
|
@ -274,4 +275,4 @@ class HybridConditioner(nn.Module):
|
|||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
|
|
@ -132,7 +132,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||
return self(text)
|
||||
|
||||
|
||||
from clip import load as load_clip
|
||||
class ClipImageEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -143,6 +142,7 @@ class ClipImageEmbedder(nn.Module):
|
|||
ucg_rate=0.
|
||||
):
|
||||
super().__init__()
|
||||
from clip import load as load_clip
|
||||
self.model, _ = load_clip(name=model, device=device, jit=jit)
|
||||
|
||||
self.antialias = antialias
|
||||
|
|
|
@ -298,15 +298,11 @@ if __name__ == "__main__":
|
|||
seed_everything(seed)
|
||||
|
||||
ucg_schedule = None
|
||||
sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 0)
|
||||
sampler = st.sidebar.selectbox("Sampler", ["DDIM", "DPM"], 0)
|
||||
if version == "Full Karlo":
|
||||
pass
|
||||
else:
|
||||
if sampler == "PLMS":
|
||||
st.warning("NOTE: Some models (such as v-pred) currently only support DDIM/DPM sampling here")
|
||||
sampler = PLMSSampler(state["model"])
|
||||
elif sampler == "DPM":
|
||||
st.warning("NOTE: Using DPM sampler with default sampling parameters (DPM-2)")
|
||||
if sampler == "DPM":
|
||||
sampler = DPMSolverSampler(state["model"])
|
||||
elif sampler == "DDIM":
|
||||
sampler = DDIMSampler(state["model"])
|
||||
|
@ -342,7 +338,6 @@ if __name__ == "__main__":
|
|||
init_img = get_init_img(batch_size=number_cols)
|
||||
with torch.no_grad():
|
||||
adm_cond = state["model"].embedder(init_img)
|
||||
adm_uc = torch.zeros_like(adm_cond)
|
||||
if state["model"].noise_augmentor is not None:
|
||||
noise_level = st.number_input("Noise Augmentation for CLIP embeddings", min_value=0,
|
||||
max_value=state["model"].noise_augmentor.max_noise_level - 1, value=0)
|
||||
|
@ -350,6 +345,7 @@ if __name__ == "__main__":
|
|||
torch.tensor([noise_level]).to(state["model"].device), '1 -> b', b=number_cols))
|
||||
# assume this gives embeddings of noise levels
|
||||
adm_cond = torch.cat((c_adm, noise_level_emb), 1)
|
||||
adm_uc = torch.zeros_like(adm_cond)
|
||||
|
||||
if st.button("Sample"):
|
||||
print("running prompt:", prompt)
|
||||
|
|
Loading…
Reference in a new issue