mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 07:34:58 +00:00
add noise-augmented unCLIP
This commit is contained in:
parent
aad6e38a78
commit
8ec7903da2
6 changed files with 76 additions and 4 deletions
|
@ -158,7 +158,7 @@ cd ../../
|
|||
```
|
||||
and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder`
|
||||
|
||||
The, run
|
||||
Then, run
|
||||
|
||||
```
|
||||
streamlit run scripts/streamlit/stablekarlo.py
|
||||
|
|
|
@ -23,11 +23,20 @@ model:
|
|||
params:
|
||||
model: "ViT-L/14"
|
||||
|
||||
noise_aug_config:
|
||||
target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
|
||||
params:
|
||||
clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th"
|
||||
timestep_dim: 768
|
||||
noise_schedule_config:
|
||||
timesteps: 1000
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
num_classes: "sequential"
|
||||
adm_in_channels: 768
|
||||
adm_in_channels: 1536
|
||||
use_checkpoint: True
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
|
|
|
@ -1796,11 +1796,13 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
|
|||
|
||||
|
||||
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
|
||||
def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5, freeze_embedder=True, *args, **kwargs):
|
||||
def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5,
|
||||
freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.embed_key = embedding_key
|
||||
self.embedding_dropout = embedding_dropout
|
||||
self._init_embedder(embedder_config, freeze_embedder)
|
||||
self._init_noise_aug(noise_aug_config)
|
||||
|
||||
def _init_embedder(self, config, freeze=True):
|
||||
embedder = instantiate_from_config(config)
|
||||
|
@ -1810,12 +1812,27 @@ class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
|
|||
for param in self.embedder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _init_noise_aug(self, config):
|
||||
if config is not None:
|
||||
# use the KARLO schedule for noise augmentation on CLIP image embeddings
|
||||
noise_augmentor = instantiate_from_config(config)
|
||||
assert isinstance(noise_augmentor, nn.Module)
|
||||
noise_augmentor = noise_augmentor.eval()
|
||||
noise_augmentor.train = disabled_train
|
||||
self.noise_augmentor = noise_augmentor
|
||||
else:
|
||||
self.noise_augmentor = None
|
||||
|
||||
def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
|
||||
outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
|
||||
z, c = outputs[0], outputs[1]
|
||||
img = batch[self.embed_key][:bs]
|
||||
img = rearrange(img, 'b h w c -> b c h w')
|
||||
c_adm = self.embedder(img)
|
||||
if self.noise_augmentor is not None:
|
||||
c_adm, noise_level_emb = self.noise_augmentor(c_adm)
|
||||
# assume this gives embeddings of noise levels
|
||||
c_adm = torch.cat((c_adm, noise_level_emb), 1)
|
||||
if self.training:
|
||||
c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
|
||||
device=c_adm.device)[:, None]) * c_adm
|
||||
|
|
|
@ -409,6 +409,15 @@ class QKVAttention(nn.Module):
|
|||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class Timestep(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, t):
|
||||
return timestep_embedding(t, self.dim)
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
|
|
|
@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
|
|||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
elif schedule == "squaredcos_cap_v2": # used for karlo prior
|
||||
# return early
|
||||
return betas_for_alpha_bar(
|
||||
n_timestep,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
|
||||
elif schedule == "sqrt_linear":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||
elif schedule == "sqrt":
|
||||
|
|
|
@ -170,7 +170,6 @@ class ClipImageEmbedder(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
|
@ -251,3 +250,34 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
|
|||
return [clip_z, t5_z]
|
||||
|
||||
|
||||
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
def __init__(self, *args, clip_stats_path, timestep_dim=256, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
|
||||
self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
|
||||
self.register_buffer("data_std", clip_std[None, :], persistent=False)
|
||||
self.time_embed = Timestep(timestep_dim)
|
||||
|
||||
def scale(self, x):
|
||||
# re-normalize to centered mean and unit variance
|
||||
x = (x - self.data_mean) * 1./self.data_std
|
||||
return x
|
||||
|
||||
def unscale(self, x):
|
||||
# back to original data stats
|
||||
x = (x * self.data_std) + self.data_mean
|
||||
return x
|
||||
|
||||
def forward(self, x, noise_level=None):
|
||||
if noise_level is None:
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
else:
|
||||
assert isinstance(noise_level, torch.Tensor)
|
||||
x = self.scale(x)
|
||||
z = self.q_sample(x, noise_level)
|
||||
z = self.unscale(z)
|
||||
noise_level = self.time_embed(noise_level)
|
||||
return z, noise_level
|
||||
|
||||
|
|
Loading…
Reference in a new issue