From 8ec7903da28e17936ff9cbbb824cebd826db113b Mon Sep 17 00:00:00 2001 From: Robin Rombach Date: Wed, 18 Jan 2023 09:54:45 +0100 Subject: [PATCH] add noise-augmented unCLIP --- README.md | 2 +- .../v2-1-stable-karlo-inference.yaml | 11 ++++++- ldm/models/diffusion/ddpm.py | 19 ++++++++++- ldm/modules/diffusionmodules/openaimodel.py | 9 ++++++ ldm/modules/diffusionmodules/util.py | 7 ++++ ldm/modules/encoders/modules.py | 32 ++++++++++++++++++- 6 files changed, 76 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 9d72b85..c06f3c0 100644 --- a/README.md +++ b/README.md @@ -158,7 +158,7 @@ cd ../../ ``` and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder` -The, run +Then, run ``` streamlit run scripts/streamlit/stablekarlo.py diff --git a/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml b/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml index 5aeb176..ea8fa93 100644 --- a/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml +++ b/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml @@ -23,11 +23,20 @@ model: params: model: "ViT-L/14" + noise_aug_config: + target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation + params: + clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th" + timestep_dim: 768 + noise_schedule_config: + timesteps: 1000 + beta_schedule: squaredcos_cap_v2 + unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: num_classes: "sequential" - adm_in_channels: 768 + adm_in_channels: 1536 use_checkpoint: True image_size: 32 # unused in_channels: 4 diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index bde253f..81bc6d3 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -1796,11 +1796,13 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion): - def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5, freeze_embedder=True, *args, **kwargs): + def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5, + freeze_embedder=True, noise_aug_config=None, *args, **kwargs): super().__init__(*args, **kwargs) self.embed_key = embedding_key self.embedding_dropout = embedding_dropout self._init_embedder(embedder_config, freeze_embedder) + self._init_noise_aug(noise_aug_config) def _init_embedder(self, config, freeze=True): embedder = instantiate_from_config(config) @@ -1810,12 +1812,27 @@ class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion): for param in self.embedder.parameters(): param.requires_grad = False + def _init_noise_aug(self, config): + if config is not None: + # use the KARLO schedule for noise augmentation on CLIP image embeddings + noise_augmentor = instantiate_from_config(config) + assert isinstance(noise_augmentor, nn.Module) + noise_augmentor = noise_augmentor.eval() + noise_augmentor.train = disabled_train + self.noise_augmentor = noise_augmentor + else: + self.noise_augmentor = None + def get_input(self, batch, k, cond_key=None, bs=None, **kwargs): outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs) z, c = outputs[0], outputs[1] img = batch[self.embed_key][:bs] img = rearrange(img, 'b h w c -> b c h w') c_adm = self.embedder(img) + if self.noise_augmentor is not None: + c_adm, noise_level_emb = self.noise_augmentor(c_adm) + # assume this gives embeddings of noise levels + c_adm = torch.cat((c_adm, noise_level_emb), 1) if self.training: c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0], device=c_adm.device)[:, None]) * c_adm diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py index 157d3b2..b5da99a 100644 --- a/ldm/modules/diffusionmodules/openaimodel.py +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -409,6 +409,15 @@ class QKVAttention(nn.Module): return count_flops_attn(model, _x, y) +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + class UNetModel(nn.Module): """ The full UNet model with attention and timestep embedding. diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index 637363d..99f6829 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, betas = 1 - alphas[1:] / alphas[:-1] betas = np.clip(betas, a_min=0, a_max=0.999) + elif schedule == "squaredcos_cap_v2": # used for karlo prior + # return early + return betas_for_alpha_bar( + n_timestep, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + elif schedule == "sqrt_linear": betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) elif schedule == "sqrt": diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 520bb27..fcc5826 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -170,7 +170,6 @@ class ClipImageEmbedder(nn.Module): return out - class FrozenOpenCLIPEmbedder(AbstractEncoder): """ Uses the OpenCLIP transformer encoder for text @@ -251,3 +250,34 @@ class FrozenCLIPT5Encoder(AbstractEncoder): return [clip_z, t5_z] +from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation +from ldm.modules.diffusionmodules.openaimodel import Timestep +class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): + def __init__(self, *args, clip_stats_path, timestep_dim=256, **kwargs): + super().__init__(*args, **kwargs) + clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu") + self.register_buffer("data_mean", clip_mean[None, :], persistent=False) + self.register_buffer("data_std", clip_std[None, :], persistent=False) + self.time_embed = Timestep(timestep_dim) + + def scale(self, x): + # re-normalize to centered mean and unit variance + x = (x - self.data_mean) * 1./self.data_std + return x + + def unscale(self, x): + # back to original data stats + x = (x * self.data_std) + self.data_mean + return x + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + x = self.scale(x) + z = self.q_sample(x, noise_level) + z = self.unscale(z) + noise_level = self.time_embed(noise_level) + return z, noise_level +