add noise-augmented unCLIP

This commit is contained in:
Robin Rombach 2023-01-18 09:54:45 +01:00
parent aad6e38a78
commit 8ec7903da2
6 changed files with 76 additions and 4 deletions

View file

@ -158,7 +158,7 @@ cd ../../
``` ```
and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder` and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder`
The, run Then, run
``` ```
streamlit run scripts/streamlit/stablekarlo.py streamlit run scripts/streamlit/stablekarlo.py

View file

@ -23,11 +23,20 @@ model:
params: params:
model: "ViT-L/14" model: "ViT-L/14"
noise_aug_config:
target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
params:
clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th"
timestep_dim: 768
noise_schedule_config:
timesteps: 1000
beta_schedule: squaredcos_cap_v2
unet_config: unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params: params:
num_classes: "sequential" num_classes: "sequential"
adm_in_channels: 768 adm_in_channels: 1536
use_checkpoint: True use_checkpoint: True
image_size: 32 # unused image_size: 32 # unused
in_channels: 4 in_channels: 4

View file

@ -1796,11 +1796,13 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion): class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5, freeze_embedder=True, *args, **kwargs): def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5,
freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.embed_key = embedding_key self.embed_key = embedding_key
self.embedding_dropout = embedding_dropout self.embedding_dropout = embedding_dropout
self._init_embedder(embedder_config, freeze_embedder) self._init_embedder(embedder_config, freeze_embedder)
self._init_noise_aug(noise_aug_config)
def _init_embedder(self, config, freeze=True): def _init_embedder(self, config, freeze=True):
embedder = instantiate_from_config(config) embedder = instantiate_from_config(config)
@ -1810,12 +1812,27 @@ class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
for param in self.embedder.parameters(): for param in self.embedder.parameters():
param.requires_grad = False param.requires_grad = False
def _init_noise_aug(self, config):
    """Set up ``self.noise_augmentor`` from an optional config.

    When ``config`` is ``None`` no augmentor is created and
    ``self.noise_augmentor`` is set to ``None``. Otherwise the object built
    by ``instantiate_from_config(config)`` must be an ``nn.Module``; it is
    switched to eval mode and its ``train`` attribute is replaced with
    ``disabled_train`` before being stored on ``self``.
    """
    if config is None:
        self.noise_augmentor = None
        return

    # use the KARLO schedule for noise augmentation on CLIP image embeddings
    augmentor = instantiate_from_config(config)
    assert isinstance(augmentor, nn.Module)
    augmentor = augmentor.eval()
    # NOTE(review): `disabled_train` is defined outside this block; presumably
    # it keeps later `.train()` calls from changing the module's mode - confirm.
    augmentor.train = disabled_train
    self.noise_augmentor = augmentor
def get_input(self, batch, k, cond_key=None, bs=None, **kwargs): def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs) outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
z, c = outputs[0], outputs[1] z, c = outputs[0], outputs[1]
img = batch[self.embed_key][:bs] img = batch[self.embed_key][:bs]
img = rearrange(img, 'b h w c -> b c h w') img = rearrange(img, 'b h w c -> b c h w')
c_adm = self.embedder(img) c_adm = self.embedder(img)
if self.noise_augmentor is not None:
c_adm, noise_level_emb = self.noise_augmentor(c_adm)
# assume this gives embeddings of noise levels
c_adm = torch.cat((c_adm, noise_level_emb), 1)
if self.training: if self.training:
c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0], c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
device=c_adm.device)[:, None]) * c_adm device=c_adm.device)[:, None]) * c_adm

View file

@ -409,6 +409,15 @@ class QKVAttention(nn.Module):
return count_flops_attn(model, _x, y) return count_flops_attn(model, _x, y)
class Timestep(nn.Module):
    """Module wrapper that applies ``timestep_embedding`` with a fixed ``dim``."""

    def __init__(self, dim):
        """Store ``dim``, the second argument passed to ``timestep_embedding``."""
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Return ``timestep_embedding(t, self.dim)`` for the given ``t``."""
        embedding = timestep_embedding(t, self.dim)
        return embedding
class UNetModel(nn.Module): class UNetModel(nn.Module):
""" """
The full UNet model with attention and timestep embedding. The full UNet model with attention and timestep embedding.

View file

@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
betas = 1 - alphas[1:] / alphas[:-1] betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999) betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "squaredcos_cap_v2": # used for karlo prior
# return early
return betas_for_alpha_bar(
n_timestep,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
elif schedule == "sqrt_linear": elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt": elif schedule == "sqrt":

View file

@ -170,7 +170,6 @@ class ClipImageEmbedder(nn.Module):
return out return out
class FrozenOpenCLIPEmbedder(AbstractEncoder): class FrozenOpenCLIPEmbedder(AbstractEncoder):
""" """
Uses the OpenCLIP transformer encoder for text Uses the OpenCLIP transformer encoder for text
@ -251,3 +250,34 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
return [clip_z, t5_z] return [clip_z, t5_z]
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.diffusionmodules.openaimodel import Timestep
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    """Noise augmentation applied to embeddings in a normalized space.

    Inputs are standardized with a mean/std pair loaded from
    ``clip_stats_path``, noised with ``self.q_sample`` and mapped back to the
    original statistics. The noise level used is also returned, embedded via
    a ``Timestep`` module.

    ``q_sample`` and ``max_noise_level`` are not defined in this class; they
    are presumably provided by ``ImageConcatWithNoiseAugmentation`` - confirm.
    """

    def __init__(self, *args, clip_stats_path, timestep_dim=256, **kwargs):
        """Load the statistics and create the noise-level embedder.

        Args:
            *args: Forwarded unchanged to the parent constructor.
            clip_stats_path: Path given to ``torch.load``; the loaded object
                must unpack into ``(mean, std)``.
            timestep_dim: ``dim`` passed to ``Timestep`` for embedding the
                noise level.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        stats = torch.load(clip_stats_path, map_location="cpu")
        clip_mean, clip_std = stats
        # A leading axis is added so the stats broadcast over the first
        # dimension of the input; non-persistent buffers stay out of the
        # state dict.
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        """Re-normalize ``x`` to centered mean and unit variance."""
        centered = x - self.data_mean
        return centered * 1. / self.data_std

    def unscale(self, x):
        """Map ``x`` back to the original data statistics."""
        rescaled = x * self.data_std
        return rescaled + self.data_mean

    def forward(self, x, noise_level=None):
        """Noise ``x`` and return it together with the embedded noise level.

        Args:
            x: Input to augment; ``x.shape[0]`` sets how many noise levels
                are drawn when none are supplied.
            noise_level: Optional ``torch.Tensor`` of noise levels. When
                ``None``, one integer level per row of ``x`` is sampled from
                ``[0, self.max_noise_level)`` on ``x``'s device.

        Returns:
            Tuple ``(z, noise_level_embedding)``: the noised input in the
            original statistics and ``self.time_embed(noise_level)``.
        """
        if noise_level is not None:
            assert isinstance(noise_level, torch.Tensor)
        else:
            noise_level = torch.randint(
                0, self.max_noise_level, (x.shape[0],), device=x.device
            ).long()

        normalized = self.scale(x)
        noised = self.q_sample(normalized, noise_level)
        restored = self.unscale(noised)
        level_embedding = self.time_embed(noise_level)
        return restored, level_embedding