diff --git a/README.md b/README.md index f413068..9d72b85 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,42 @@ Note: The inference config for all model versions is designed to be used with EM For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from non-EMA to EMA weights. +### Stable Diffusion Meets Karlo +![upscaling-x4](assets/stable-samples/stable-unclip/panda.jpg) +_++++++ NOTE: preliminary checkpoint for internal testing ++++++_ + +Recently, [KakaoBrain](https://kakaobrain.com/) openly released [Karlo](https://github.com/kakaobrain/karlo), a pretrained, large-scale replication of [unCLIP](https://arxiv.org/abs/2204.06125) (also known as DALLĀ·E 2). +We introduce _Stable Karlo_, a combination of the Karlo CLIP image embedding prior, and Stable Diffusion v2.1. +More precisely, we finetuned SD 2.1 to accept a CLIP ViT-L/14 image embedding in addition to the text encodings. +This means that the model can be used to produce image variations in the style of unCLIP, but can also be combined with the +embedding prior of KARLO and directly decodes to 768x768 pixel resolution. + +To run the model, first download the KARLO checkpoints +```shell +mkdir -p checkpoints/karlo_models +cd checkpoints/karlo_models +wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt +wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th +wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt +cd ../../ +``` +and the finetuned SD2.1 checkpoint [+++prelim private upload on HF+++] from [https://huggingface.co/stabilityai/stable-unclip-preview](https://huggingface.co/stabilityai/stable-unclip-preview), and put the ckpt into the `checkpoints folder` + +The, run + +``` +streamlit run scripts/streamlit/stablekarlo.py +``` + +The script optionally supports sampling from the full Karlo model. To do so, you need to download the 64x64 decoder and 64->256 upscaler +via +```shell +cd checkpoints/karlo_models +wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt +wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt +cd ../../ +``` + ### Image Modification with Stable Diffusion ![depth2img-stable2](assets/stable-samples/depth2img/merged-0000.png) diff --git a/assets/stable-samples/stable-unclip/panda.jpg b/assets/stable-samples/stable-unclip/panda.jpg new file mode 100644 index 0000000..49aa1ba Binary files /dev/null and b/assets/stable-samples/stable-unclip/panda.jpg differ diff --git a/configs/karlo/decoder_900M_vit_l.yaml b/configs/karlo/decoder_900M_vit_l.yaml new file mode 100644 index 0000000..02a3530 --- /dev/null +++ b/configs/karlo/decoder_900M_vit_l.yaml @@ -0,0 +1,37 @@ +model: + type: t2i-decoder + diffusion_sampler: uniform + hparams: + image_size: 64 + num_channels: 320 + num_res_blocks: 3 + channel_mult: '' + attention_resolutions: 32,16,8 + num_heads: -1 + num_head_channels: 64 + num_heads_upsample: -1 + use_scale_shift_norm: true + dropout: 0.1 + clip_dim: 768 + clip_emb_mult: 4 + text_ctx: 77 + xf_width: 1536 + xf_layers: 0 + xf_heads: 0 + xf_final_ln: false + resblock_updown: true + learn_sigma: true + text_drop: 0.3 + clip_emb_type: image + clip_emb_drop: 0.1 + use_plm: true + +diffusion: + steps: 1000 + learn_sigma: true + sigma_small: false + noise_schedule: squaredcos_cap_v2 + use_kl: false + predict_xstart: false + rescale_learned_sigmas: true + timestep_respacing: '' diff --git a/configs/karlo/improved_sr_64_256_1.4B.yaml b/configs/karlo/improved_sr_64_256_1.4B.yaml new file mode 100644 index 0000000..282d3cb --- /dev/null +++ b/configs/karlo/improved_sr_64_256_1.4B.yaml @@ -0,0 +1,27 @@ +model: + type: improved_sr_64_256 + diffusion_sampler: uniform + hparams: + channels: 320 + depth: 3 + channels_multiple: + - 1 + - 2 + - 3 + - 4 + dropout: 0.0 + +diffusion: + steps: 1000 + learn_sigma: false + sigma_small: true + noise_schedule: squaredcos_cap_v2 + use_kl: false + predict_xstart: false + rescale_learned_sigmas: true + timestep_respacing: '7' + + +sampling: + timestep_respacing: '7' # fix + clip_denoise: true diff --git a/configs/karlo/prior_1B_vit_l.yaml b/configs/karlo/prior_1B_vit_l.yaml new file mode 100644 index 0000000..159330d --- /dev/null +++ b/configs/karlo/prior_1B_vit_l.yaml @@ -0,0 +1,21 @@ +model: + type: prior + diffusion_sampler: uniform + hparams: + text_ctx: 77 + xf_width: 2048 + xf_layers: 20 + xf_heads: 32 + xf_final_ln: true + text_drop: 0.2 + clip_dim: 768 + +diffusion: + steps: 1000 + learn_sigma: false + sigma_small: true + noise_schedule: squaredcos_cap_v2 + use_kl: false + predict_xstart: true + rescale_learned_sigmas: false + timestep_respacing: '' diff --git a/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml b/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml new file mode 100644 index 0000000..da867b4 --- /dev/null +++ b/configs/stable-diffusion/v2-1-stable-karlo-inference.yaml @@ -0,0 +1,74 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion + params: + embedding_dropout: 0.25 + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 96 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn-adm + scale_factor: 0.18215 + monitor: val/loss_simple_ema + + embedder_config: + target: ldm.modules.encoders.modules.ClipImageEmbedder + params: + model: "ViT-L/14" + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + num_classes: "sequential" + adm_in_channels: 768 + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: "softmax-xformers" + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 6090212..bde253f 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -1793,3 +1793,58 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): log = super().log_images(*args, **kwargs) log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') return log + + +class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion): + def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5, freeze_embedder=True, *args, **kwargs): + super().__init__(*args, **kwargs) + self.embed_key = embedding_key + self.embedding_dropout = embedding_dropout + self._init_embedder(embedder_config, freeze_embedder) + + def _init_embedder(self, config, freeze=True): + embedder = instantiate_from_config(config) + if freeze: + self.embedder = embedder.eval() + self.embedder.train = disabled_train + for param in self.embedder.parameters(): + param.requires_grad = False + + def get_input(self, batch, k, cond_key=None, bs=None, **kwargs): + outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs) + z, c = outputs[0], outputs[1] + img = batch[self.embed_key][:bs] + img = rearrange(img, 'b h w c -> b c h w') + c_adm = self.embedder(img) + if self.training: + c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0], + device=c_adm.device)[:, None]) * c_adm + all_conds = {"c_crossattn": [c], "c_adm": c_adm} + noutputs = [z, all_conds] + noutputs.extend(outputs[2:]) + return noutputs + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, **kwargs): + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True, + return_original_cond=True) + log["inputs"] = x + log["reconstruction"] = xrec + assert self.model.conditioning_key is not None + assert self.cond_stage_key in ["caption", "txt"] + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', '')) + unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.) + + uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext + with ema_scope(f"Sampling"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True, + ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.), + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_, ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + return log diff --git a/ldm/modules/karlo/__init__.py b/ldm/modules/karlo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ldm/modules/karlo/diffusers_pipeline.py b/ldm/modules/karlo/diffusers_pipeline.py new file mode 100644 index 0000000..07f72b3 --- /dev/null +++ b/ldm/modules/karlo/diffusers_pipeline.py @@ -0,0 +1,512 @@ +# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import functional as F + +from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel +from ...pipelines import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import UnCLIPScheduler +from ...utils import is_accelerate_available, logging, randn_tensor +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using unCLIP + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Args: + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution unet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution unet. Used in the last step of the super resolution diffusion process. + prior_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the prior denoising process. Just a modified DDPMScheduler. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. + """ + + prior: PriorTransformer + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + prior_scheduler: UnCLIPScheduler + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + def __init__( + self, + prior: PriorTransformer, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + prior_scheduler: UnCLIPScheduler, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + text_embeddings = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return text_embeddings, text_encoder_hidden_states, text_mask + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list + models = [ + self.decoder, + self.text_proj, + self.text_encoder, + self.super_res_first, + self.super_res_last, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): + return self.device + for module in self.decoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[torch.Generator] = None, + prior_latents: Optional[torch.FloatTensor] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. This can only be left undefined if + `text_model_output` and `text_attention_mask` is passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): + Pre-generated noisy latents to be used as inputs for the prior. + decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs + can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left to `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + """ + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + else: + batch_size = text_model_output[0].shape[0] + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + text_embeddings.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=text_embeddings, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + text_embeddings=text_embeddings, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.in_channels + height = self.decoder.sample_size + width = self.decoder.sample_size + + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.in_channels // 2 + height = self.super_res_first.sample_size + width = self.super_res_first.sample_size + + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = super_res_latents + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) \ No newline at end of file diff --git a/ldm/modules/karlo/kakao/__init__.py b/ldm/modules/karlo/kakao/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ldm/modules/karlo/kakao/models/__init__.py b/ldm/modules/karlo/kakao/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ldm/modules/karlo/kakao/models/clip.py b/ldm/modules/karlo/kakao/models/clip.py new file mode 100644 index 0000000..961d815 --- /dev/null +++ b/ldm/modules/karlo/kakao/models/clip.py @@ -0,0 +1,182 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------------ +# Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/) +# ------------------------------------------------------------------------------------ + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import clip + +from clip.model import CLIP, convert_weights +from clip.simple_tokenizer import SimpleTokenizer, default_bpe + + +"""===== Monkey-Patching original CLIP for JIT compile =====""" + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = F.layer_norm( + x.type(torch.float32), + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ) + return ret.type(orig_type) + + +clip.model.LayerNorm = LayerNorm +delattr(clip.model.CLIP, "forward") + +"""===== End of Monkey-Patching =====""" + + +class CustomizedCLIP(CLIP): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @torch.jit.export + def encode_image(self, image): + return self.visual(image) + + @torch.jit.export + def encode_text(self, text): + # re-define this function to return unpooled text features + + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + x_seq = x + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x_out, x_seq + + @torch.jit.ignore + def forward(self, image, text): + super().forward(image, text) + + @classmethod + def load_from_checkpoint(cls, ckpt_path: str): + state_dict = torch.load(ckpt_path, map_location="cpu").state_dict() + + vit = "visual.proj" in state_dict + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [ + k + for k in state_dict.keys() + if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") + ] + ) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round( + (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 + ) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [ + len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"visual.layer{b}") + ) + ) + for b in [1, 2, 3, 4] + ] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round( + (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 + ) + vision_patch_size = None + assert ( + output_width**2 + 1 + == state_dict["visual.attnpool.positional_embedding"].shape[0] + ) + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith("transformer.resblocks") + ) + ) + + model = cls( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + model.eval() + model.float() + return model + + +class CustomizedTokenizer(SimpleTokenizer): + def __init__(self): + super().__init__(bpe_path=default_bpe()) + + self.sot_token = self.encoder["<|startoftext|>"] + self.eot_token = self.encoder["<|endoftext|>"] + + def padded_tokens_and_mask(self, texts, text_ctx): + assert isinstance(texts, list) and all( + isinstance(elem, str) for elem in texts + ), "texts should be a list of strings" + + all_tokens = [ + [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts + ] + + mask = [ + [True] * min(text_ctx, len(tokens)) + + [False] * max(text_ctx - len(tokens), 0) + for tokens in all_tokens + ] + mask = torch.tensor(mask, dtype=torch.bool) + result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int) + for i, tokens in enumerate(all_tokens): + if len(tokens) > text_ctx: + tokens = tokens[:text_ctx] + tokens[-1] = self.eot_token + result[i, : len(tokens)] = torch.tensor(tokens) + + return result, mask diff --git a/ldm/modules/karlo/kakao/models/decoder_model.py b/ldm/modules/karlo/kakao/models/decoder_model.py new file mode 100644 index 0000000..84e96c9 --- /dev/null +++ b/ldm/modules/karlo/kakao/models/decoder_model.py @@ -0,0 +1,193 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +import copy +import torch + +from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion +from ldm.modules.karlo.kakao.modules.unet import PLMImUNet + + +class Text2ImProgressiveModel(torch.nn.Module): + """ + A decoder that generates 64x64px images based on the text prompt. + + :param config: yaml config to define the decoder. + :param tokenizer: tokenizer used in clip. + """ + + def __init__( + self, + config, + tokenizer, + ): + super().__init__() + + self._conf = config + self._model_conf = config.model.hparams + self._diffusion_kwargs = dict( + steps=config.diffusion.steps, + learn_sigma=config.diffusion.learn_sigma, + sigma_small=config.diffusion.sigma_small, + noise_schedule=config.diffusion.noise_schedule, + use_kl=config.diffusion.use_kl, + predict_xstart=config.diffusion.predict_xstart, + rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, + timestep_respacing=config.diffusion.timestep_respacing, + ) + self._tokenizer = tokenizer + + self.model = self.create_plm_dec_model() + + cf_token, cf_mask = self.set_cf_text_tensor() + self.register_buffer("cf_token", cf_token, persistent=False) + self.register_buffer("cf_mask", cf_mask, persistent=False) + + @classmethod + def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True): + ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + model = cls(config, tokenizer) + model.load_state_dict(ckpt, strict=strict) + return model + + def create_plm_dec_model(self): + image_size = self._model_conf.image_size + if self._model_conf.channel_mult == "": + if image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 1, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported image size: {image_size}") + else: + channel_mult = tuple( + int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",") + ) + assert 2 ** (len(channel_mult) + 2) == image_size + + attention_ds = [] + for res in self._model_conf.attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return PLMImUNet( + text_ctx=self._model_conf.text_ctx, + xf_width=self._model_conf.xf_width, + in_channels=3, + model_channels=self._model_conf.num_channels, + out_channels=6 if self._model_conf.learn_sigma else 3, + num_res_blocks=self._model_conf.num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=self._model_conf.dropout, + channel_mult=channel_mult, + num_heads=self._model_conf.num_heads, + num_head_channels=self._model_conf.num_head_channels, + num_heads_upsample=self._model_conf.num_heads_upsample, + use_scale_shift_norm=self._model_conf.use_scale_shift_norm, + resblock_updown=self._model_conf.resblock_updown, + clip_dim=self._model_conf.clip_dim, + clip_emb_mult=self._model_conf.clip_emb_mult, + clip_emb_type=self._model_conf.clip_emb_type, + clip_emb_drop=self._model_conf.clip_emb_drop, + ) + + def set_cf_text_tensor(self): + return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) + + def get_sample_fn(self, timestep_respacing): + use_ddim = timestep_respacing.startswith(("ddim", "fast")) + + diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) + diffusion_kwargs.update(timestep_respacing=timestep_respacing) + diffusion = create_gaussian_diffusion(**diffusion_kwargs) + sample_fn = ( + diffusion.ddim_sample_loop_progressive + if use_ddim + else diffusion.p_sample_loop_progressive + ) + + return sample_fn + + def forward( + self, + txt_feat, + txt_feat_seq, + tok, + mask, + img_feat=None, + cf_guidance_scales=None, + timestep_respacing=None, + ): + # cfg should be enabled in inference + assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) + assert img_feat is not None + + bsz = txt_feat.shape[0] + img_sz = self._model_conf.image_size + + def guided_model_fn(x_t, ts, **kwargs): + half = x_t[: len(x_t) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.model(combined, ts, **kwargs) + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * ( + cond_eps - uncond_eps + ) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + cf_feat = self.model.cf_param.unsqueeze(0) + cf_feat = cf_feat.expand(bsz // 2, -1) + feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0) + + cond = { + "y": feat, + "txt_feat": txt_feat, + "txt_feat_seq": txt_feat_seq, + "mask": mask, + } + sample_fn = self.get_sample_fn(timestep_respacing) + sample_outputs = sample_fn( + guided_model_fn, + (bsz, 3, img_sz, img_sz), + noise=None, + device=txt_feat.device, + clip_denoised=True, + model_kwargs=cond, + ) + + for out in sample_outputs: + sample = out["sample"] + yield sample if cf_guidance_scales is None else sample[ + : sample.shape[0] // 2 + ] + + +class Text2ImModel(Text2ImProgressiveModel): + def forward( + self, + txt_feat, + txt_feat_seq, + tok, + mask, + img_feat=None, + cf_guidance_scales=None, + timestep_respacing=None, + ): + last_out = None + for out in super().forward( + txt_feat, + txt_feat_seq, + tok, + mask, + img_feat, + cf_guidance_scales, + timestep_respacing, + ): + last_out = out + return last_out diff --git a/ldm/modules/karlo/kakao/models/prior_model.py b/ldm/modules/karlo/kakao/models/prior_model.py new file mode 100644 index 0000000..03ef230 --- /dev/null +++ b/ldm/modules/karlo/kakao/models/prior_model.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +import copy +import torch + +from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion +from ldm.modules.karlo.kakao.modules.xf import PriorTransformer + + +class PriorDiffusionModel(torch.nn.Module): + """ + A prior that generates clip image feature based on the text prompt. + + :param config: yaml config to define the decoder. + :param tokenizer: tokenizer used in clip. + :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance). + :param clip_std: std to noramlize the clip image feature (zero-mean, unit variance). + """ + + def __init__(self, config, tokenizer, clip_mean, clip_std): + super().__init__() + + self._conf = config + self._model_conf = config.model.hparams + self._diffusion_kwargs = dict( + steps=config.diffusion.steps, + learn_sigma=config.diffusion.learn_sigma, + sigma_small=config.diffusion.sigma_small, + noise_schedule=config.diffusion.noise_schedule, + use_kl=config.diffusion.use_kl, + predict_xstart=config.diffusion.predict_xstart, + rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, + timestep_respacing=config.diffusion.timestep_respacing, + ) + self._tokenizer = tokenizer + + self.register_buffer("clip_mean", clip_mean[None, :], persistent=False) + self.register_buffer("clip_std", clip_std[None, :], persistent=False) + + causal_mask = self.get_causal_mask() + self.register_buffer("causal_mask", causal_mask, persistent=False) + + self.model = PriorTransformer( + text_ctx=self._model_conf.text_ctx, + xf_width=self._model_conf.xf_width, + xf_layers=self._model_conf.xf_layers, + xf_heads=self._model_conf.xf_heads, + xf_final_ln=self._model_conf.xf_final_ln, + clip_dim=self._model_conf.clip_dim, + ) + + cf_token, cf_mask = self.set_cf_text_tensor() + self.register_buffer("cf_token", cf_token, persistent=False) + self.register_buffer("cf_mask", cf_mask, persistent=False) + + @classmethod + def load_from_checkpoint( + cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True + ): + ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + model = cls(config, tokenizer, clip_mean, clip_std) + model.load_state_dict(ckpt, strict=strict) + return model + + def set_cf_text_tensor(self): + return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) + + def get_sample_fn(self, timestep_respacing): + use_ddim = timestep_respacing.startswith(("ddim", "fast")) + + diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) + diffusion_kwargs.update(timestep_respacing=timestep_respacing) + diffusion = create_gaussian_diffusion(**diffusion_kwargs) + sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop + + return sample_fn + + def get_causal_mask(self): + seq_len = self._model_conf.text_ctx + 4 + mask = torch.empty(seq_len, seq_len) + mask.fill_(float("-inf")) + mask.triu_(1) + mask = mask[None, ...] + return mask + + def forward( + self, + txt_feat, + txt_feat_seq, + mask, + cf_guidance_scales=None, + timestep_respacing=None, + denoised_fn=True, + ): + # cfg should be enabled in inference + assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) + + bsz_ = txt_feat.shape[0] + bsz = bsz_ // 2 + + def guided_model_fn(x_t, ts, **kwargs): + half = x_t[: len(x_t) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.model(combined, ts, **kwargs) + eps, rest = ( + model_out[:, : int(x_t.shape[1])], + model_out[:, int(x_t.shape[1]) :], + ) + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * ( + cond_eps - uncond_eps + ) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + cond = { + "text_emb": txt_feat, + "text_enc": txt_feat_seq, + "mask": mask, + "causal_mask": self.causal_mask, + } + sample_fn = self.get_sample_fn(timestep_respacing) + sample = sample_fn( + guided_model_fn, + (bsz_, self.model.clip_dim), + noise=None, + device=txt_feat.device, + clip_denoised=False, + denoised_fn=lambda x: torch.clamp(x, -10, 10), + model_kwargs=cond, + ) + sample = (sample * self.clip_std) + self.clip_mean + + return sample[:bsz] diff --git a/ldm/modules/karlo/kakao/models/sr_256_1k.py b/ldm/modules/karlo/kakao/models/sr_256_1k.py new file mode 100644 index 0000000..1e874f6 --- /dev/null +++ b/ldm/modules/karlo/kakao/models/sr_256_1k.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive + + +class SupRes256to1kProgressive(SupRes64to256Progressive): + pass # no difference currently diff --git a/ldm/modules/karlo/kakao/models/sr_64_256.py b/ldm/modules/karlo/kakao/models/sr_64_256.py new file mode 100644 index 0000000..32687af --- /dev/null +++ b/ldm/modules/karlo/kakao/models/sr_64_256.py @@ -0,0 +1,88 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +import copy +import torch + +from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel +from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion + + +class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module): + """ + ImprovedSR model fine-tunes the pretrained DDPM-based SR model by using adversarial and perceptual losses. + In specific, the low-resolution sample is iteratively recovered by 6 steps with the frozen pretrained SR model. + In the following additional one step, a seperate fine-tuned model recovers high-frequency details. + This approach greatly improves the fidelity of images of 256x256px, even with small number of reverse steps. + """ + + def __init__(self, config): + super().__init__() + + self._config = config + self._diffusion_kwargs = dict( + steps=config.diffusion.steps, + learn_sigma=config.diffusion.learn_sigma, + sigma_small=config.diffusion.sigma_small, + noise_schedule=config.diffusion.noise_schedule, + use_kl=config.diffusion.use_kl, + predict_xstart=config.diffusion.predict_xstart, + rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, + ) + + self.model_first_steps = SuperResUNetModel( + in_channels=3, # auto-changed to 6 inside the model + model_channels=config.model.hparams.channels, + out_channels=3, + num_res_blocks=config.model.hparams.depth, + attention_resolutions=(), # no attention + dropout=config.model.hparams.dropout, + channel_mult=config.model.hparams.channels_multiple, + resblock_updown=True, + use_middle_attention=False, + ) + self.model_last_step = SuperResUNetModel( + in_channels=3, # auto-changed to 6 inside the model + model_channels=config.model.hparams.channels, + out_channels=3, + num_res_blocks=config.model.hparams.depth, + attention_resolutions=(), # no attention + dropout=config.model.hparams.dropout, + channel_mult=config.model.hparams.channels_multiple, + resblock_updown=True, + use_middle_attention=False, + ) + + @classmethod + def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True): + ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + model = cls(config) + model.load_state_dict(ckpt, strict=strict) + return model + + def get_sample_fn(self, timestep_respacing): + diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) + diffusion_kwargs.update(timestep_respacing=timestep_respacing) + diffusion = create_gaussian_diffusion(**diffusion_kwargs) + return diffusion.p_sample_loop_progressive_for_improved_sr + + def forward(self, low_res, timestep_respacing="7", **kwargs): + assert ( + timestep_respacing == "7" + ), "different respacing method may work, but no guaranteed" + + sample_fn = self.get_sample_fn(timestep_respacing) + sample_outputs = sample_fn( + self.model_first_steps, + self.model_last_step, + shape=low_res.shape, + clip_denoised=True, + model_kwargs=dict(low_res=low_res), + **kwargs, + ) + for x in sample_outputs: + sample = x["sample"] + yield sample diff --git a/ldm/modules/karlo/kakao/modules/__init__.py b/ldm/modules/karlo/kakao/modules/__init__.py new file mode 100644 index 0000000..11d4358 --- /dev/null +++ b/ldm/modules/karlo/kakao/modules/__init__.py @@ -0,0 +1,49 @@ +# ------------------------------------------------------------------------------------ +# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + + +from .diffusion import gaussian_diffusion as gd +from .diffusion.respace import ( + SpacedDiffusion, + space_timesteps, +) + + +def create_gaussian_diffusion( + steps, + learn_sigma, + sigma_small, + noise_schedule, + use_kl, + predict_xstart, + rescale_learned_sigmas, + timestep_respacing, +): + betas = gd.get_named_beta_schedule(noise_schedule, steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if not timestep_respacing: + timestep_respacing = [steps] + + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + ) diff --git a/ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py b/ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000..6a111aa --- /dev/null +++ b/ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py @@ -0,0 +1,828 @@ +# ------------------------------------------------------------------------------------ +# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + +import enum +import math + +import numpy as np +import torch as th + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace( + beta_start, beta_end, warmup_time, dtype=np.float64 + ) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion(th.nn.Module): + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + ): + super(GaussianDiffusion, self).__init__() + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0) + assert alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) + log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod) + sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod) + sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = ( + betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + posterior_log_variance_clipped = np.log( + np.append(posterior_variance[1], posterior_variance[1:]) + ) + posterior_mean_coef1 = ( + betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + posterior_mean_coef2 = ( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ) + + self.register_buffer("betas", th.from_numpy(betas), persistent=False) + self.register_buffer( + "alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False + ) + self.register_buffer( + "alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False + ) + self.register_buffer( + "alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False + ) + + self.register_buffer( + "sqrt_alphas_cumprod", th.from_numpy(sqrt_alphas_cumprod), persistent=False + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + th.from_numpy(sqrt_one_minus_alphas_cumprod), + persistent=False, + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", + th.from_numpy(log_one_minus_alphas_cumprod), + persistent=False, + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", + th.from_numpy(sqrt_recip_alphas_cumprod), + persistent=False, + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + th.from_numpy(sqrt_recipm1_alphas_cumprod), + persistent=False, + ) + + self.register_buffer( + "posterior_variance", th.from_numpy(posterior_variance), persistent=False + ) + self.register_buffer( + "posterior_log_variance_clipped", + th.from_numpy(posterior_log_variance_clipped), + persistent=False, + ) + self.register_buffer( + "posterior_mean_coef1", + th.from_numpy(posterior_mean_coef1), + persistent=False, + ) + self.register_buffer( + "posterior_mean_coef2", + th.from_numpy(posterior_mean_coef2), + persistent=False, + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + **ignore_kwargs, + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(th.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + th.cat([self.posterior_variance[1][None], self.betas[1:]]), + th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for idx, i in enumerate(indices): + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def p_sample_loop_progressive_for_improved_sr( + self, + model, + model_aux, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Modified version of p_sample_loop_progressive for sampling from the improved sr model + """ + + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for idx, i in enumerate(indices): + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model_aux if len(indices) - 1 == idx else model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = arr.to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/ldm/modules/karlo/kakao/modules/diffusion/respace.py b/ldm/modules/karlo/kakao/modules/diffusion/respace.py new file mode 100644 index 0000000..70c808f --- /dev/null +++ b/ldm/modules/karlo/kakao/modules/diffusion/respace.py @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------------------ +# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + + +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + elif section_counts == "fast27": + steps = space_timesteps(num_timesteps, "10,10,3,2,2") + # Help reduce DDIM artifacts from noisiest timesteps. + steps.remove(num_timesteps - 1) + steps.add(num_timesteps - 3) + return steps + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + timestep_map = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + timestep_map.append(i) + kwargs["betas"] = th.tensor(new_betas).numpy() + super().__init__(**kwargs) + self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False) + + def p_mean_variance(self, model, *args, **kwargs): + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + def wrapped(x, ts, **kwargs): + ts_cpu = ts.detach().to("cpu") + return model( + x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs + ) + + return wrapped diff --git a/ldm/modules/karlo/kakao/modules/nn.py b/ldm/modules/karlo/kakao/modules/nn.py new file mode 100644 index 0000000..2eef3f5 --- /dev/null +++ b/ldm/modules/karlo/kakao/modules/nn.py @@ -0,0 +1,114 @@ +# ------------------------------------------------------------------------------------ +# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + +import math + +import torch as th +import torch.nn as nn +import torch.nn.functional as F + + +class GroupNorm32(nn.GroupNorm): + def __init__(self, num_groups, num_channels, swish, eps=1e-5): + super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) + self.swish = swish + + def forward(self, x): + y = super().forward(x.float()).to(x.dtype) + if self.swish == 1.0: + y = F.silu(y) + elif self.swish: + y = y * F.sigmoid(y * float(self.swish)) + return y + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def normalization(channels, swish=0.0): + """ + Make a standard normalization layer, with an optional swish activation. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp( + -math.log(max_period) + * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device) + / half + ) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) diff --git a/ldm/modules/karlo/kakao/modules/resample.py b/ldm/modules/karlo/kakao/modules/resample.py new file mode 100644 index 0000000..485421a --- /dev/null +++ b/ldm/modules/karlo/kakao/modules/resample.py @@ -0,0 +1,68 @@ +# ------------------------------------------------------------------------------------ +# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + +from abc import abstractmethod + +import torch as th + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(th.nn.Module): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / th.sum(w) + indices = p.multinomial(batch_size, replacement=True) + weights = 1 / (len(p) * p[indices]) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + super(UniformSampler, self).__init__() + self.diffusion = diffusion + self.register_buffer( + "_weights", th.ones([diffusion.num_timesteps]), persistent=False + ) + + def weights(self): + return self._weights diff --git a/ldm/modules/karlo/kakao/modules/unet.py b/ldm/modules/karlo/kakao/modules/unet.py new file mode 100644 index 0000000..c99d0b7 --- /dev/null +++ b/ldm/modules/karlo/kakao/modules/unet.py @@ -0,0 +1,792 @@ +# ------------------------------------------------------------------------------------ +# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + +import math +from abc import abstractmethod + +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .nn import ( + avg_pool_nd, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) +from .xf import LayerNorm + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, encoder_out=None, mask=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, AttentionBlock): + x = layer(x, encoder_out, mask=mask) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=1 + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels, swish=1.0), + nn.Identity(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization( + self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0 + ), + nn.SiLU() if use_scale_shift_norm else nn.Identity(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class ResBlockNoTimeEmbedding(nn.Module): + """ + A residual block without time embedding + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + **kwargs, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + + self.in_layers = nn.Sequential( + normalization(channels, swish=1.0), + nn.Identity(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.out_layers = nn.Sequential( + normalization(self.out_channels, swish=1.0), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb=None): + """ + Apply the block to a Tensor, NOT conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + assert emb is None + + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + encoder_channels=None, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels, swish=0.0) + self.qkv = conv_nd(1, channels, channels * 3, 1) + self.attention = QKVAttention(self.num_heads) + + if encoder_channels is not None: + self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x, encoder_out=None, mask=None): + b, c, *spatial = x.shape + qkv = self.qkv(self.norm(x).view(b, c, -1)) + if encoder_out is not None: + encoder_out = self.encoder_kv(encoder_out) + h = self.attention(qkv, encoder_out, mask=mask) + else: + h = self.attention(qkv) + h = self.proj_out(h) + return x + h.reshape(b, c, *spatial) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, encoder_kv=None, mask=None): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + if encoder_kv is not None: + assert encoder_kv.shape[1] == self.n_heads * ch * 2 + ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) + k = th.cat([ek, k], dim=-1) + v = th.cat([ev, v], dim=-1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) + if mask is not None: + mask = F.pad(mask, (0, length), value=0.0) + mask = ( + mask.unsqueeze(1) + .expand(-1, self.n_heads, -1) + .reshape(bs * self.n_heads, 1, -1) + ) + weight = weight + mask + weight = th.softmax(weight, dim=-1) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param clip_dim: dimension of clip feature. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param encoder_channels: use to make the dimension of query and kv same in AttentionBlock. + :param use_time_embedding: use time embedding for condition. + """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + clip_dim=None, + use_checkpoint=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + use_middle_attention=True, + resblock_updown=False, + encoder_channels=None, + use_time_embedding=True, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.clip_dim = clip_dim + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.use_middle_attention = use_middle_attention + self.use_time_embedding = use_time_embedding + + if self.use_time_embedding: + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.clip_dim is not None: + self.clip_emb = nn.Linear(clip_dim, time_embed_dim) + else: + time_embed_dim = None + + CustomResidualBlock = ( + ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding + ) + ch = input_ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + *( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + ), + ) + if self.use_middle_attention + else tuple(), # add AttentionBlock or not + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + CustomResidualBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(model_channels * mult) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch, swish=1.0), + nn.Identity(), + zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), + ) + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.clip_dim is not None + ), "must specify y if and only if the model is clip-rep-conditional" + + hs = [] + if self.use_time_embedding: + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + if self.clip_dim is not None: + emb = emb + self.clip_emb(y) + else: + emb = None + + h = x + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb) + + return self.out(h) + + +class SuperResUNetModel(UNetModel): + """ + A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. + Assumes that the shape of low-resolution and the input should be the same. + """ + + def __init__(self, *args, **kwargs): + if "in_channels" in kwargs: + kwargs = dict(kwargs) + kwargs["in_channels"] = kwargs["in_channels"] * 2 + else: + # Curse you, Python. Or really, just curse positional arguments :|. + args = list(args) + args[1] = args[1] * 2 + super().__init__(*args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + assert new_height == low_res.shape[2] and new_width == low_res.shape[3] + + x = th.cat([x, low_res], dim=1) + return super().forward(x, timesteps, **kwargs) + + +class PLMImUNet(UNetModel): + """ + A UNetModel that conditions on text with a pretrained text encoder in CLIP. + + :param text_ctx: number of text tokens to expect. + :param xf_width: width of the transformer. + :param clip_emb_mult: #extra tokens by projecting clip text feature. + :param clip_emb_type: type of condition (here, we fix clip image feature). + :param clip_emb_drop: dropout rato of clip image feature for cfg. + """ + + def __init__( + self, + text_ctx, + xf_width, + *args, + clip_emb_mult=None, + clip_emb_type="image", + clip_emb_drop=0.0, + **kwargs, + ): + self.text_ctx = text_ctx + self.xf_width = xf_width + self.clip_emb_mult = clip_emb_mult + self.clip_emb_type = clip_emb_type + self.clip_emb_drop = clip_emb_drop + + if not xf_width: + super().__init__(*args, **kwargs, encoder_channels=None) + else: + super().__init__(*args, **kwargs, encoder_channels=xf_width) + + # Project text encoded feat seq from pre-trained text encoder in CLIP + self.text_seq_proj = nn.Sequential( + nn.Linear(self.clip_dim, xf_width), + LayerNorm(xf_width), + ) + # Project CLIP text feat + self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4) + + assert clip_emb_mult is not None + assert clip_emb_type == "image" + assert self.clip_dim is not None, "CLIP representation dim should be specified" + + self.clip_tok_proj = nn.Linear( + self.clip_dim, self.xf_width * self.clip_emb_mult + ) + if self.clip_emb_drop > 0: + self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32)) + + def proc_clip_emb_drop(self, feat): + if self.clip_emb_drop > 0: + bsz, feat_dim = feat.shape + assert ( + feat_dim == self.clip_dim + ), f"CLIP input dim: {feat_dim}, model CLIP dim: {self.clip_dim}" + drop_idx = th.rand((bsz,), device=feat.device) < self.clip_emb_drop + feat = th.where( + drop_idx[..., None], self.cf_param[None].type_as(feat), feat + ) + return feat + + def forward( + self, x, timesteps, txt_feat=None, txt_feat_seq=None, mask=None, y=None + ): + bsz = x.shape[0] + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = emb + self.clip_emb(y) + + xf_out = self.text_seq_proj(txt_feat_seq) + xf_out = xf_out.permute(0, 2, 1) + emb = emb + self.text_feat_proj(txt_feat) + xf_out = th.cat( + [ + self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult), + xf_out, + ], + dim=2, + ) + mask = F.pad(mask, (self.clip_emb_mult, 0), value=True) + mask = th.where(mask, 0.0, float("-inf")) + + h = x + for module in self.input_blocks: + h = module(h, emb, xf_out, mask=mask) + hs.append(h) + h = self.middle_block(h, emb, xf_out, mask=mask) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, xf_out, mask=mask) + h = self.out(h) + + return h diff --git a/ldm/modules/karlo/kakao/modules/xf.py b/ldm/modules/karlo/kakao/modules/xf.py new file mode 100644 index 0000000..66d7d4a --- /dev/null +++ b/ldm/modules/karlo/kakao/modules/xf.py @@ -0,0 +1,231 @@ +# ------------------------------------------------------------------------------------ +# Adapted from the repos below: +# (a) Guided-Diffusion (https://github.com/openai/guided-diffusion) +# (b) CLIP ViT (https://github.com/openai/CLIP/) +# ------------------------------------------------------------------------------------ + +import math + +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .nn import timestep_embedding + + +def convert_module_to_f16(param): + """ + Convert primitive modules to float16. + """ + if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + param.weight.data = param.weight.data.half() + if param.bias is not None: + param.bias.data = param.bias.data.half() + + +class LayerNorm(nn.LayerNorm): + """ + Implementation that supports fp16 inputs but fp32 gains/biases. + """ + + def forward(self, x: th.Tensor): + return super().forward(x.float()).to(x.dtype) + + +class MultiheadAttention(nn.Module): + def __init__(self, n_ctx, width, heads): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadAttention(heads, n_ctx) + + def forward(self, x, mask=None): + x = self.c_qkv(x) + x = self.attention(x, mask=mask) + x = self.c_proj(x) + return x + + +class MLP(nn.Module): + def __init__(self, width): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4) + self.c_proj = nn.Linear(width * 4, width) + self.gelu = nn.GELU() + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, n_heads: int, n_ctx: int): + super().__init__() + self.n_heads = n_heads + self.n_ctx = n_ctx + + def forward(self, qkv, mask=None): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.n_heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.n_heads, -1) + q, k, v = th.split(qkv, attn_ch, dim=-1) + weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale) + wdtype = weight.dtype + if mask is not None: + weight = weight + mask[:, None, ...] + weight = th.softmax(weight, dim=-1).type(wdtype) + return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + n_ctx: int, + width: int, + heads: int, + ): + super().__init__() + + self.attn = MultiheadAttention( + n_ctx, + width, + heads, + ) + self.ln_1 = LayerNorm(width) + self.mlp = MLP(width) + self.ln_2 = LayerNorm(width) + + def forward(self, x, mask=None): + x = x + self.attn(self.ln_1(x), mask=mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + n_ctx: int, + width: int, + layers: int, + heads: int, + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + n_ctx, + width, + heads, + ) + for _ in range(layers) + ] + ) + + def forward(self, x, mask=None): + for block in self.resblocks: + x = block(x, mask=mask) + return x + + +class PriorTransformer(nn.Module): + """ + A Causal Transformer that conditions on CLIP text embedding, text. + + :param text_ctx: number of text tokens to expect. + :param xf_width: width of the transformer. + :param xf_layers: depth of the transformer. + :param xf_heads: heads in the transformer. + :param xf_final_ln: use a LayerNorm after the output layer. + :param clip_dim: dimension of clip feature. + """ + + def __init__( + self, + text_ctx, + xf_width, + xf_layers, + xf_heads, + xf_final_ln, + clip_dim, + ): + super().__init__() + + self.text_ctx = text_ctx + self.xf_width = xf_width + self.xf_layers = xf_layers + self.xf_heads = xf_heads + self.clip_dim = clip_dim + self.ext_len = 4 + + self.time_embed = nn.Sequential( + nn.Linear(xf_width, xf_width), + nn.SiLU(), + nn.Linear(xf_width, xf_width), + ) + self.text_enc_proj = nn.Linear(clip_dim, xf_width) + self.text_emb_proj = nn.Linear(clip_dim, xf_width) + self.clip_img_proj = nn.Linear(clip_dim, xf_width) + self.out_proj = nn.Linear(xf_width, clip_dim) + self.transformer = Transformer( + text_ctx + self.ext_len, + xf_width, + xf_layers, + xf_heads, + ) + if xf_final_ln: + self.final_ln = LayerNorm(xf_width) + else: + self.final_ln = None + + self.positional_embedding = nn.Parameter( + th.empty(1, text_ctx + self.ext_len, xf_width) + ) + self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width))) + + nn.init.normal_(self.prd_emb, std=0.01) + nn.init.normal_(self.positional_embedding, std=0.01) + + def forward( + self, + x, + timesteps, + text_emb=None, + text_enc=None, + mask=None, + causal_mask=None, + ): + bsz = x.shape[0] + mask = F.pad(mask, (0, self.ext_len), value=True) + + t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width)) + text_enc = self.text_enc_proj(text_enc) + text_emb = self.text_emb_proj(text_emb) + x = self.clip_img_proj(x) + + input_seq = [ + text_enc, + text_emb[:, None, :], + t_emb[:, None, :], + x[:, None, :], + self.prd_emb.to(x.dtype).expand(bsz, -1, -1), + ] + input = th.cat(input_seq, dim=1) + input = input + self.positional_embedding.to(input.dtype) + + mask = th.where(mask, 0.0, float("-inf")) + mask = (mask[:, None, :] + causal_mask).to(input.dtype) + + out = self.transformer(input, mask=mask) + if self.final_ln is not None: + out = self.final_ln(out) + + out = self.out_proj(out[:, -1]) + + return out diff --git a/ldm/modules/karlo/kakao/sampler.py b/ldm/modules/karlo/kakao/sampler.py new file mode 100644 index 0000000..b56bf2f --- /dev/null +++ b/ldm/modules/karlo/kakao/sampler.py @@ -0,0 +1,272 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. + +# source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15 +# ------------------------------------------------------------------------------------ + +from typing import Iterator + +import torch +import torchvision.transforms.functional as TVF +from torchvision.transforms import InterpolationMode + +from .template import BaseSampler, CKPT_PATH + + +class T2ISampler(BaseSampler): + """ + A sampler for text-to-image generation. + :param root_dir: directory for model checkpoints. + :param sampling_type: ["default", "fast"] + """ + + def __init__( + self, + root_dir: str, + sampling_type: str = "default", + ): + super().__init__(root_dir, sampling_type) + + @classmethod + def from_pretrained( + cls, + root_dir: str, + clip_model_path: str, + clip_stat_path: str, + sampling_type: str = "default", + ): + + model = cls( + root_dir=root_dir, + sampling_type=sampling_type, + ) + model.load_clip(clip_model_path) + model.load_prior( + f"{CKPT_PATH['prior']}", + clip_stat_path=clip_stat_path, + prior_config="configs/karlo/prior_1B_vit_l.yaml" + ) + model.load_decoder(f"{CKPT_PATH['decoder']}", decoder_config="configs/karlo/decoder_900M_vit_l.yaml") + model.load_sr_64_256(CKPT_PATH["sr_256"], sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml") + return model + + def preprocess( + self, + prompt: str, + bsz: int, + ): + """Setup prompts & cfg scales""" + prompts_batch = [prompt for _ in range(bsz)] + + prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch) + prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda") + + decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) + decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") + + """ Get CLIP text feature """ + clip_model = self._clip + tokenizer = self._tokenizer + max_txt_length = self._prior.model.text_ctx + + tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length) + cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length) + if not (cf_token.shape == tok.shape): + cf_token = cf_token.expand(tok.shape[0], -1) + cf_mask = cf_mask.expand(tok.shape[0], -1) + + tok = torch.cat([tok, cf_token], dim=0) + mask = torch.cat([mask, cf_mask], dim=0) + + tok, mask = tok.to(device="cuda"), mask.to(device="cuda") + txt_feat, txt_feat_seq = clip_model.encode_text(tok) + + return ( + prompts_batch, + prior_cf_scales_batch, + decoder_cf_scales_batch, + txt_feat, + txt_feat_seq, + tok, + mask, + ) + + def __call__( + self, + prompt: str, + bsz: int, + progressive_mode=None, + ) -> Iterator[torch.Tensor]: + assert progressive_mode in ("loop", "stage", "final") + with torch.no_grad(), torch.cuda.amp.autocast(): + ( + prompts_batch, + prior_cf_scales_batch, + decoder_cf_scales_batch, + txt_feat, + txt_feat_seq, + tok, + mask, + ) = self.preprocess( + prompt, + bsz, + ) + + """ Transform CLIP text feature into image feature """ + img_feat = self._prior( + txt_feat, + txt_feat_seq, + mask, + prior_cf_scales_batch, + timestep_respacing=self._prior_sm, + ) + + """ Generate 64x64px images """ + images_64_outputs = self._decoder( + txt_feat, + txt_feat_seq, + tok, + mask, + img_feat, + cf_guidance_scales=decoder_cf_scales_batch, + timestep_respacing=self._decoder_sm, + ) + + images_64 = None + for k, out in enumerate(images_64_outputs): + images_64 = out + if progressive_mode == "loop": + yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) + if progressive_mode == "stage": + yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) + + images_64 = torch.clamp(images_64, -1, 1) + + """ Upsample 64x64 to 256x256 """ + images_256 = TVF.resize( + images_64, + [256, 256], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ) + images_256_outputs = self._sr_64_256( + images_256, timestep_respacing=self._sr_sm + ) + + for k, out in enumerate(images_256_outputs): + images_256 = out + if progressive_mode == "loop": + yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) + if progressive_mode == "stage": + yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) + + yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0) + + +class PriorSampler(BaseSampler): + """ + A sampler for text-to-image generation, but only the prior. + :param root_dir: directory for model checkpoints. + :param sampling_type: ["default", "fast"] + """ + + def __init__( + self, + root_dir: str, + sampling_type: str = "default", + ): + super().__init__(root_dir, sampling_type) + + @classmethod + def from_pretrained( + cls, + root_dir: str, + clip_model_path: str, + clip_stat_path: str, + sampling_type: str = "default", + ): + model = cls( + root_dir=root_dir, + sampling_type=sampling_type, + ) + model.load_clip(clip_model_path) + model.load_prior( + f"{CKPT_PATH['prior']}", + clip_stat_path=clip_stat_path, + prior_config="configs/karlo/prior_1B_vit_l.yaml" + ) + return model + + def preprocess( + self, + prompt: str, + bsz: int, + ): + """Setup prompts & cfg scales""" + prompts_batch = [prompt for _ in range(bsz)] + + prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch) + prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda") + + decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) + decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") + + """ Get CLIP text feature """ + clip_model = self._clip + tokenizer = self._tokenizer + max_txt_length = self._prior.model.text_ctx + + tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length) + cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length) + if not (cf_token.shape == tok.shape): + cf_token = cf_token.expand(tok.shape[0], -1) + cf_mask = cf_mask.expand(tok.shape[0], -1) + + tok = torch.cat([tok, cf_token], dim=0) + mask = torch.cat([mask, cf_mask], dim=0) + + tok, mask = tok.to(device="cuda"), mask.to(device="cuda") + txt_feat, txt_feat_seq = clip_model.encode_text(tok) + + return ( + prompts_batch, + prior_cf_scales_batch, + decoder_cf_scales_batch, + txt_feat, + txt_feat_seq, + tok, + mask, + ) + + def __call__( + self, + prompt: str, + bsz: int, + progressive_mode=None, + ) -> Iterator[torch.Tensor]: + assert progressive_mode in ("loop", "stage", "final") + with torch.no_grad(), torch.cuda.amp.autocast(): + ( + prompts_batch, + prior_cf_scales_batch, + decoder_cf_scales_batch, + txt_feat, + txt_feat_seq, + tok, + mask, + ) = self.preprocess( + prompt, + bsz, + ) + + """ Transform CLIP text feature into image feature """ + img_feat = self._prior( + txt_feat, + txt_feat_seq, + mask, + prior_cf_scales_batch, + timestep_respacing=self._prior_sm, + ) + + yield img_feat diff --git a/ldm/modules/karlo/kakao/template.py b/ldm/modules/karlo/kakao/template.py new file mode 100644 index 0000000..949e80e --- /dev/null +++ b/ldm/modules/karlo/kakao/template.py @@ -0,0 +1,141 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +import os +import logging +import torch + +from omegaconf import OmegaConf + +from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer +from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel +from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel +from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel + + +SAMPLING_CONF = { + "default": { + "prior_sm": "25", + "prior_n_samples": 1, + "prior_cf_scale": 4.0, + "decoder_sm": "50", + "decoder_cf_scale": 8.0, + "sr_sm": "7", + }, + "fast": { + "prior_sm": "25", + "prior_n_samples": 1, + "prior_cf_scale": 4.0, + "decoder_sm": "25", + "decoder_cf_scale": 8.0, + "sr_sm": "7", + }, +} + +CKPT_PATH = { + "prior": "prior-ckpt-step=01000000-of-01000000.ckpt", + "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt", + "sr_256": "improved-sr-ckpt-step=1.2M.ckpt", +} + + +class BaseSampler: + _PRIOR_CLASS = PriorDiffusionModel + _DECODER_CLASS = Text2ImProgressiveModel + _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel + + def __init__( + self, + root_dir: str, + sampling_type: str = "fast", + ): + self._root_dir = root_dir + + sampling_type = SAMPLING_CONF[sampling_type] + self._prior_sm = sampling_type["prior_sm"] + self._prior_n_samples = sampling_type["prior_n_samples"] + self._prior_cf_scale = sampling_type["prior_cf_scale"] + + assert self._prior_n_samples == 1 + + self._decoder_sm = sampling_type["decoder_sm"] + self._decoder_cf_scale = sampling_type["decoder_cf_scale"] + + self._sr_sm = sampling_type["sr_sm"] + + def __repr__(self): + line = "" + line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n" + line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n" + line += f"SR(64->256), sampling method: {self._sr_sm}" + + return line + + def load_clip(self, clip_path: str): + clip = CustomizedCLIP.load_from_checkpoint( + os.path.join(self._root_dir, clip_path) + ) + clip = torch.jit.script(clip) + clip.cuda() + clip.eval() + + self._clip = clip + self._tokenizer = CustomizedTokenizer() + + def load_prior( + self, + ckpt_path: str, + clip_stat_path: str, + prior_config: str = "configs/prior_1B_vit_l.yaml" + ): + logging.info(f"Loading prior: {ckpt_path}") + + config = OmegaConf.load(prior_config) + clip_mean, clip_std = torch.load( + os.path.join(self._root_dir, clip_stat_path), map_location="cpu" + ) + + prior = self._PRIOR_CLASS.load_from_checkpoint( + config, + self._tokenizer, + clip_mean, + clip_std, + os.path.join(self._root_dir, ckpt_path), + strict=True, + ) + prior.cuda() + prior.eval() + logging.info("done.") + + self._prior = prior + + def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"): + logging.info(f"Loading decoder: {ckpt_path}") + + config = OmegaConf.load(decoder_config) + decoder = self._DECODER_CLASS.load_from_checkpoint( + config, + self._tokenizer, + os.path.join(self._root_dir, ckpt_path), + strict=True, + ) + decoder.cuda() + decoder.eval() + logging.info("done.") + + self._decoder = decoder + + def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"): + logging.info(f"Loading SR(64->256): {ckpt_path}") + + config = OmegaConf.load(sr_config) + sr = self._SR256_CLASS.load_from_checkpoint( + config, os.path.join(self._root_dir, ckpt_path), strict=True + ) + sr.cuda() + sr.eval() + logging.info("done.") + + self._sr_64_256 = sr \ No newline at end of file diff --git a/scripts/streamlit/stablekarlo.py b/scripts/streamlit/stablekarlo.py new file mode 100644 index 0000000..e57500c --- /dev/null +++ b/scripts/streamlit/stablekarlo.py @@ -0,0 +1,381 @@ +import importlib +import streamlit as st +import torch +import cv2 +import numpy as np +import PIL +from omegaconf import OmegaConf +from PIL import Image +from tqdm import trange +import io, os +from torch import autocast +from einops import rearrange, repeat +from torchvision.utils import make_grid +from pytorch_lightning import seed_everything +from contextlib import nullcontext + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler + + +torch.set_grad_enabled(False) + +PROMPTS_ROOT = "scripts/prompts/" +SAVE_PATH = "outputs/demo/stable-karlo/" + +VERSION2SPECS = { + "Stable Karlo": {"H": 768, "W": 768, "C": 4, "f": 8}, + "Full Karlo": {} +} + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + importlib.invalidate_caches() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_interactive_image(): + image = st.file_uploader("Input", type=["jpg", "JPEG", "png"]) + if image is not None: + image = Image.open(image) + if not image.mode == "RGB": + image = image.convert("RGB") + return image + + +def load_img(display=True): + image = get_interactive_image() + if display: + st.image(image) + w, h = image.size + print(f"loaded input image of size ({w}, {h})") + w, h = map(lambda x: x - x % 64, (w, h)) + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2. * image - 1. + + +def get_init_img(batch_size=1): + init_image = load_img().cuda() + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + return init_image + + +def sample( + model, + prompt, + n_runs=3, + n_samples=2, + H=512, + W=512, + C=4, + f=8, + scale=10.0, + ddim_steps=50, + ddim_eta=0.0, + callback=None, + skip_single_save=False, + save_grid=True, + ucg_schedule=None, + negative_prompt="", + adm_cond=None, + adm_uc=None, + use_full_precision=False, + only_adm_cond=False +): + batch_size = n_samples + precision_scope = autocast if not use_full_precision else nullcontext + if use_full_precision: st.warning(f"Sampling {model.__class__.__name__} at full precision.") + if isinstance(prompt, str): + prompt = [prompt] + prompts = batch_size * prompt + + outputs = st.empty() + + with precision_scope("cuda"): + with model.ema_scope(): + all_samples = list() + for n in trange(n_runs, desc="Sampling"): + shape = [C, H // f, W // f] + if not only_adm_cond: + uc = None + if scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [negative_prompt]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + if isinstance(model, Txt2ImgDiffusionWithPooledInput): + c, uc = c[0], uc[0] + + if adm_cond is not None: + if adm_cond.shape[0] == 1: + adm_cond = repeat(adm_cond, '1 ... -> b ...', b=batch_size) + if adm_uc is None: + st.warning("Not guiding via c_adm") + adm_uc = adm_cond + else: + if adm_uc.shape[0] == 1: + adm_uc = repeat(adm_uc, '1 ... -> b ...', b=batch_size) + if not only_adm_cond: + c = {"c_crossattn": [c], "c_adm": adm_cond} + uc = {"c_crossattn": [uc], "c_adm": adm_uc} + else: + c = adm_cond + uc = adm_uc + samples_ddim, _ = sampler.sample(S=ddim_steps, + conditioning=c, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=ddim_eta, + x_T=None, + callback=callback, + ucg_schedule=ucg_schedule + ) + + x_samples = model.decode_first_stage(samples_ddim) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + if not skip_single_save: + base_count = len(os.listdir(os.path.join(SAVE_PATH, "samples"))) + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(SAVE_PATH, "samples", f"{base_count:09}.png")) + base_count += 1 + + all_samples.append(x_samples) + + # get grid of all samples + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n h) (b w) c') + outputs.image(grid.cpu().numpy()) + + # additionally, save grid + grid = Image.fromarray((255. * grid.cpu().numpy()).astype(np.uint8)) + if save_grid: + grid_count = len(os.listdir(SAVE_PATH)) - 1 + grid.save(os.path.join(SAVE_PATH, f'grid-{grid_count:06}.png')) + + return x_samples + + +def make_oscillating_guidance_schedule(num_steps, max_weight=15., min_weight=1.): + schedule = list() + for i in range(num_steps): + if float(i / num_steps) < 0.1: + schedule.append(max_weight) + elif i % 2 == 0: + schedule.append(min_weight) + else: + schedule.append(max_weight) + print(f"OSCILLATING GUIDANCE SCHEDULE: \n {schedule}") + return schedule + + +def torch2np(x): + x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8) + x = x.permute(0, 2, 3, 1).detach().cpu().numpy() + return x + + +@st.cache(allow_output_mutation=True, suppress_st_warning=True) +def init(version="Stable Karlo", load_karlo_prior=False): + state = dict() + if not "model" in state: + if version == "Stable Karlo": + config = "configs/stable-diffusion/v2-1-stable-karlo-inference.yaml" + ckpt = "checkpoints/v2-1-stable-unclip-ft.ckpt" + + elif version == "Full Karlo": + from ldm.modules.karlo.kakao.sampler import T2ISampler + st.info("Loading full KARLO..") + karlo = T2ISampler.from_pretrained( + root_dir="checkpoints/karlo_models", + clip_model_path="ViT-L-14.pt", + clip_stat_path="ViT-L-14_stats.th", + sampling_type="default", + ) + state["karlo_prior"] = karlo + state["msg"] = "loaded full Karlo" + return state + else: + raise ValueError(f"version {version} unknown!") + + config = OmegaConf.load(config) + model, msg = load_model_from_config(config, ckpt, vae_sd=None) + state["msg"] = msg + + if load_karlo_prior: + from ldm.modules.karlo.kakao.sampler import PriorSampler + st.info("Loading KARLO CLIP prior...") + karlo_prior = PriorSampler.from_pretrained( + root_dir="/fsx/robin/checkpoints/karlo_models", + clip_model_path="ViT-L-14.pt", + clip_stat_path="ViT-L-14_stats.th", + sampling_type="default", + ) + state["karlo_prior"] = karlo_prior + state["model"] = model + state["ckpt"] = ckpt + state["config"] = config + return state + + +def load_model_from_config(config, ckpt, verbose=False, vae_sd=None): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + msg = None + if "global_step" in pl_sd: + msg = f"This is global step {pl_sd['global_step']}. " + if "model_ema.num_updates" in pl_sd["state_dict"]: + msg += f"And we got {pl_sd['state_dict']['model_ema.num_updates']} EMA updates." + global_step = pl_sd.get("global_step", "?") + sd = pl_sd["state_dict"] + if vae_sd is not None: + for k in sd.keys(): + if "first_stage" in k: + sd[k] = vae_sd[k[len("first_stage_model."):]] + + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + print(f"Loaded global step {global_step}") + return model, msg + + +if __name__ == "__main__": + st.title("Stable Karlo") + mode = "txt2img" + version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0) + use_karlo = st.checkbox("Use KARLO prior", False) + state = init(version=version, vae_version=vae_version, load_karlo_prior=use_karlo) + st.info(state["msg"]) + prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse") + negative_prompt = st.text_input("Negative Prompt", "") + scale = st.number_input("cfg-scale", value=10., min_value=-100., max_value=100.) + number_rows = st.number_input("num rows", value=2, min_value=1, max_value=10) + number_cols = st.number_input("num cols", value=2, min_value=1, max_value=10) + default_steps = 25 + steps = st.sidebar.number_input("steps", value=default_steps, min_value=1, max_value=1000) + eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.) + force_full_precision = st.sidebar.checkbox("Force FP32", False) + if version != "Full Karlo": + H = st.sidebar.number_input("H", value=VERSION2SPECS[version]["H"], min_value=64, max_value=2048) + W = st.sidebar.number_input("W", value=VERSION2SPECS[version]["W"], min_value=64, max_value=2048) + C = VERSION2SPECS[version]["C"] + f = VERSION2SPECS[version]["f"] + + SAVE_PATH = os.path.join(SAVE_PATH, version + "_" + vae_version + "-decoder") + os.makedirs(os.path.join(SAVE_PATH, "samples"), exist_ok=True) + + seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) + seed_everything(seed) + + ucg_schedule = None + sampler = st.sidebar.selectbox("Sampler", ["DDIM", "PLMS", "DPM"], 2) + if version == "Full Karlo": + pass + else: + if sampler == "PLMS": + st.warning("NOTE: Some models (such as v-pred) currently only support DDIM/DPM sampling here") + sampler = PLMSSampler(state["model"]) + elif sampler == "DPM": + st.warning("NOTE: Using DPM sampler with default sampling parameters (DPM-2)") + sampler = DPMSolverSampler(state["model"]) + elif sampler == "DDIM": + sampler = DDIMSampler(state["model"]) + if st.checkbox("Try oscillating guidance?", False): + ucg_schedule = make_oscillating_guidance_schedule(num_steps=steps, max_weight=scale, min_weight=1.) + else: + raise ValueError(f"unknown sampler {sampler}!") + + adm_cond, adm_uc = None, None + if use_karlo: + # uses the prior + karlo_sampler = state["karlo_prior"] + with torch.no_grad(): + karlo_prediction = iter( + karlo_sampler( + prompt=prompt, + bsz=number_cols, + progressive_mode="final", + ) + ).__next__() + adm_cond = karlo_prediction + adm_uc = torch.zeros_like(karlo_prediction) + + else: + init_img = get_init_img(batch_size=number_cols) + with torch.no_grad(): + adm_cond = state["model"].embedder(init_img) + adm_uc = torch.zeros_like(adm_cond) + + if st.button("Sample"): + print("running prompt:", prompt) + st.text("Sampling") + t_progress = st.progress(0) + result = st.empty() + + def t_callback(t): + t_progress.progress(min((t + 1) / steps, 1.)) + + if version == "KARLO": + outputs = st.empty() + karlo_sampler = state["karlo_prior"] + all_samples = list() + with torch.no_grad(): + for _ in range(number_rows): + karlo_prediction = iter( + karlo_sampler( + prompt=prompt, + bsz=number_cols, + progressive_mode="final", + ) + ).__next__() + all_samples.append(karlo_prediction) + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n h) (b w) c') + outputs.image(grid.cpu().numpy()) + + else: + samples = sample( + state["model"], + prompt, + n_runs=number_rows, + n_samples=number_cols, + H=H, W=W, C=C, f=f, + scale=scale, + ddim_steps=steps, + ddim_eta=eta, + callback=t_callback, + ucg_schedule=ucg_schedule, + negative_prompt=negative_prompt, + adm_cond=adm_cond, adm_uc=adm_uc, + use_full_precision=force_full_precision, + only_adm_cond=False + ) + \ No newline at end of file