mirror of https://github.com/Stability-AI/stablediffusion.git
add noise-augmented unCLIP
This commit is contained in:
parent aad6e38a78
commit 8ec7903da2

6 changed files with 76 additions and 4 deletions
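In brief: ImageEmbeddingConditionedLatentDiffusion gains an optional noise augmentor that perturbs the CLIP image embedding with the Karlo cosine noise schedule before it conditions the UNet. The noisy embedding is concatenated with a sinusoidal embedding of the sampled noise level, which is why adm_in_channels grows from 768 to 1536 in the config. Supporting changes add a Timestep wrapper module, the squaredcos_cap_v2 beta schedule, and the CLIPEmbeddingNoiseAugmentation encoder, plus a README typo fix.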
@@ -158,7 +158,7 @@ cd ../../
 ```
 and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder`
 
-The, run
+Then, run
 
 ```
 streamlit run scripts/streamlit/stablekarlo.py
@@ -23,11 +23,20 @@ model:
       params:
         model: "ViT-L/14"
 
+    noise_aug_config:
+      target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
+      params:
+        clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th"
+        timestep_dim: 768
+        noise_schedule_config:
+          timesteps: 1000
+          beta_schedule: squaredcos_cap_v2
+
     unet_config:
       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
       params:
         num_classes: "sequential"
-        adm_in_channels: 768
+        adm_in_channels: 1536
         use_checkpoint: True
         image_size: 32 # unused
         in_channels: 4
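Why adm_in_channels doubles: the 768-d CLIP image embedding is concatenated with a 768-d noise-level embedding (timestep_dim: 768) in get_input below, so the UNet's "sequential" class-conditioning input becomes 1536-d. A minimal shape check (tensor values here are illustrative; the names mirror get_input):

```python
import torch

c_adm = torch.randn(4, 768)            # CLIP ViT-L/14 image embeddings
noise_level_emb = torch.randn(4, 768)  # sinusoidal embedding of the sampled noise level
c_adm = torch.cat((c_adm, noise_level_emb), dim=1)
assert c_adm.shape == (4, 1536)        # matches adm_in_channels: 1536
```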
@@ -1796,11 +1796,13 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
 
 
 class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
-    def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5, freeze_embedder=True, *args, **kwargs):
+    def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5,
+                 freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.embed_key = embedding_key
         self.embedding_dropout = embedding_dropout
         self._init_embedder(embedder_config, freeze_embedder)
+        self._init_noise_aug(noise_aug_config)
 
     def _init_embedder(self, config, freeze=True):
         embedder = instantiate_from_config(config)
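Note that noise_aug_config defaults to None, so configs that predate this commit and lack a noise_aug_config block still instantiate the class unchanged; the augmentor is only wired in when the config supplies one.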
@@ -1810,12 +1812,27 @@ class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
             for param in self.embedder.parameters():
                 param.requires_grad = False
 
+    def _init_noise_aug(self, config):
+        if config is not None:
+            # use the KARLO schedule for noise augmentation on CLIP image embeddings
+            noise_augmentor = instantiate_from_config(config)
+            assert isinstance(noise_augmentor, nn.Module)
+            noise_augmentor = noise_augmentor.eval()
+            noise_augmentor.train = disabled_train
+            self.noise_augmentor = noise_augmentor
+        else:
+            self.noise_augmentor = None
+
     def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
         outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
         z, c = outputs[0], outputs[1]
         img = batch[self.embed_key][:bs]
         img = rearrange(img, 'b h w c -> b c h w')
         c_adm = self.embedder(img)
+        if self.noise_augmentor is not None:
+            c_adm, noise_level_emb = self.noise_augmentor(c_adm)
+            # assume this gives embeddings of noise levels
+            c_adm = torch.cat((c_adm, noise_level_emb), 1)
         if self.training:
             c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
                                                                                device=c_adm.device)[:, None]) * c_adm
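_init_noise_aug freezes the augmentor with the same pattern the repo uses for its frozen encoders: switch to eval() and overwrite .train so later .train() calls become no-ops. A sketch of that idiom (disabled_train mirrors the existing one-liner helper in ldm/models/diffusion/ddpm.py):

```python
def disabled_train(self, mode=True):
    """Overwrite module.train so train/eval mode can no longer change."""
    return self

# noise_augmentor = noise_augmentor.eval()
# noise_augmentor.train = disabled_train  # .train() now returns self unchanged
```

Also note that augmentation runs before the Bernoulli embedding dropout, so during training a dropped conditioning zeroes the concatenated (embedding, noise-level) vector as a whole.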
@@ -409,6 +409,15 @@ class QKVAttention(nn.Module):
         return count_flops_attn(model, _x, y)
 
 
+class Timestep(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, t):
+        return timestep_embedding(t, self.dim)
+
+
 class UNetModel(nn.Module):
     """
     The full UNet model with attention and timestep embedding.
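Timestep is a thin nn.Module wrapper around the sinusoidal timestep_embedding already defined in this codebase, so the embedding dimension travels with the module that owns it. A usage sketch, assuming the standard sinusoidal embedding:

```python
import torch
from ldm.modules.diffusionmodules.openaimodel import Timestep

embed = Timestep(768)                       # matches timestep_dim: 768 in the config
noise_level = torch.randint(0, 1000, (4,))  # one sampled level per batch item
emb = embed(noise_level)
assert emb.shape == (4, 768)                # one sinusoidal embedding per level
```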
@@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
         betas = 1 - alphas[1:] / alphas[:-1]
         betas = np.clip(betas, a_min=0, a_max=0.999)
 
+    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
+        # return early
+        return betas_for_alpha_bar(
+            n_timestep,
+            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+        )
+
     elif schedule == "sqrt_linear":
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
     elif schedule == "sqrt":
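squaredcos_cap_v2 is the cosine schedule of Nichol & Dhariwal, capped at beta = 0.999. The new branch assumes a betas_for_alpha_bar helper in scope; a sketch of that helper as it appears in OpenAI's guided-diffusion codebase, from which the name and the cosine lambda are taken:

```python
import math
import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """Build betas so the cumulative product of (1 - beta) tracks alpha_bar(t)."""
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

# e.g. the cosine alpha_bar used above:
betas = betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)
```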
@@ -170,7 +170,6 @@ class ClipImageEmbedder(nn.Module):
         return out
 
 
-
 class FrozenOpenCLIPEmbedder(AbstractEncoder):
     """
     Uses the OpenCLIP transformer encoder for text
@@ -251,3 +250,34 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
         return [clip_z, t5_z]
 
 
+from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
+from ldm.modules.diffusionmodules.openaimodel import Timestep
+
+class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
+    def __init__(self, *args, clip_stats_path, timestep_dim=256, **kwargs):
+        super().__init__(*args, **kwargs)
+        clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
+        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
+        self.register_buffer("data_std", clip_std[None, :], persistent=False)
+        self.time_embed = Timestep(timestep_dim)
+
+    def scale(self, x):
+        # re-normalize to centered mean and unit variance
+        x = (x - self.data_mean) * 1./self.data_std
+        return x
+
+    def unscale(self, x):
+        # back to original data stats
+        x = (x * self.data_std) + self.data_mean
+        return x
+
+    def forward(self, x, noise_level=None):
+        if noise_level is None:
+            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+        else:
+            assert isinstance(noise_level, torch.Tensor)
+        x = self.scale(x)
+        z = self.q_sample(x, noise_level)
+        z = self.unscale(z)
+        noise_level = self.time_embed(noise_level)
+        return z, noise_level
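End to end, the augmentor normalizes a CLIP image embedding with precomputed dataset statistics, diffuses it to a sampled noise level via the inherited q_sample, de-normalizes, and returns the noisy embedding together with the sinusoidal embedding of that level. A usage sketch, assuming the parent ImageConcatWithNoiseAugmentation accepts a noise_schedule_config keyword and exposes max_noise_level (default 1000), and that the stats file from the config above exists locally:

```python
import torch
from ldm.modules.encoders.modules import CLIPEmbeddingNoiseAugmentation

aug = CLIPEmbeddingNoiseAugmentation(
    noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"},
    clip_stats_path="checkpoints/karlo_models/ViT-L-14_stats.th",  # per the config above
    timestep_dim=768,
)
emb = torch.randn(4, 768)              # CLIP ViT-L/14 image embeddings
noisy_emb, level_emb = aug(emb)        # augment + embed the sampled noise level
cond = torch.cat((noisy_emb, level_emb), 1)
assert cond.shape == (4, 1536)         # what the UNet's adm path consumes
```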