StableDiffusion/ldm/modules/karlo/kakao/template.py
2023-03-24 11:18:44 +01:00

141 lines
No EOL
4.2 KiB
Python

# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import os
import logging
import torch
from omegaconf import OmegaConf
from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
# Sampling presets, keyed by the name passed as BaseSampler's `sampling_type`.
# Per stage: "*_sm" values are strings (presumably sampling-method / step
# specifications — TODO confirm against the model code) and "*_cf_scale" values
# are floats (assumed to be classifier-free guidance scales — verify).
# "fast" is identical to "default" except that decoder_sm is "25" instead of "50".
SAMPLING_CONF = {
    "default": {
        "prior_sm": "25",
        "prior_n_samples": 1,
        "prior_cf_scale": 4.0,
        "decoder_sm": "50",
        "decoder_cf_scale": 8.0,
        "sr_sm": "7",
    },
}
SAMPLING_CONF["fast"] = {**SAMPLING_CONF["default"], "decoder_sm": "25"}
# Checkpoint file names for each model stage. Presumably these are meant to be
# passed to the BaseSampler.load_* methods, which join them onto root_dir —
# confirm against the callers, since nothing in this module reads CKPT_PATH.
CKPT_PATH = dict(
    prior="prior-ckpt-step=01000000-of-01000000.ckpt",
    decoder="decoder-ckpt-step=01000000-of-01000000.ckpt",
    sr_256="improved-sr-ckpt-step=1.2M.ckpt",
)
class BaseSampler:
    """Shared setup for Karlo samplers: sampling presets and model loading.

    An instance stores the hyper-parameters of one ``SAMPLING_CONF`` preset and
    provides loaders for the CLIP model/tokenizer, the prior, the decoder and
    the 64->256 super-resolution model. Each loader moves its model to CUDA,
    switches it to eval mode and stores it on the instance.

    Checkpoint and stat paths given to the loaders are joined onto
    ``root_dir``; YAML config paths are passed to ``OmegaConf.load`` unchanged,
    so relative config paths resolve against the current working directory.

    ``load_clip`` must run before ``load_prior`` and ``load_decoder``, because
    both pass the tokenizer it creates to their model's constructor.
    """

    # Looked up through ``self`` in the loaders, so a subclass can substitute
    # a different model class by overriding these attributes.
    _PRIOR_CLASS = PriorDiffusionModel
    _DECODER_CLASS = Text2ImProgressiveModel
    _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel

    def __init__(
        self,
        root_dir: str,
        sampling_type: str = "fast",
    ):
        """Store ``root_dir`` and unpack the chosen sampling preset.

        Args:
            root_dir: Directory onto which checkpoint/stat paths passed to the
                ``load_*`` methods are joined.
            sampling_type: Name of a preset in ``SAMPLING_CONF``.

        Raises:
            KeyError: If ``sampling_type`` is not a key of ``SAMPLING_CONF``.
            ValueError: If the preset's ``prior_n_samples`` is not 1, which is
                the only value this sampler supports.
        """
        self._root_dir = root_dir

        try:
            conf = SAMPLING_CONF[sampling_type]
        except KeyError:
            # Same exception type as a plain dict lookup, but name the options.
            raise KeyError(
                f"Unknown sampling_type {sampling_type!r}; "
                f"expected one of {list(SAMPLING_CONF)}"
            ) from None

        self._prior_sm = conf["prior_sm"]
        self._prior_n_samples = conf["prior_n_samples"]
        self._prior_cf_scale = conf["prior_cf_scale"]

        # An explicit check rather than `assert`, which `python -O` strips.
        if self._prior_n_samples != 1:
            raise ValueError(
                f"prior_n_samples must be 1, got {self._prior_n_samples!r}"
            )

        self._decoder_sm = conf["decoder_sm"]
        self._decoder_cf_scale = conf["decoder_cf_scale"]

        self._sr_sm = conf["sr_sm"]

    def __repr__(self) -> str:
        """Summarise the per-stage sampling settings, one stage per line."""
        return "\n".join(
            [
                f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}",
                f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}",
                f"SR(64->256), sampling method: {self._sr_sm}",
            ]
        )

    def _require_tokenizer(self):
        """Return the tokenizer created by ``load_clip``; fail clearly if it is missing."""
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            raise RuntimeError(
                "Tokenizer is not loaded; call load_clip() before "
                "load_prior() or load_decoder()."
            )
        return tokenizer

    def load_clip(self, clip_path: str) -> None:
        """Load and TorchScript-compile CLIP, move it to CUDA and create the tokenizer.

        Args:
            clip_path: CLIP checkpoint path, relative to ``root_dir``.
        """
        clip = CustomizedCLIP.load_from_checkpoint(
            os.path.join(self._root_dir, clip_path)
        )
        clip = torch.jit.script(clip)
        clip.cuda()
        clip.eval()

        self._clip = clip
        self._tokenizer = CustomizedTokenizer()

    def load_prior(
        self,
        ckpt_path: str,
        clip_stat_path: str,
        prior_config: str = "configs/prior_1B_vit_l.yaml",
    ) -> None:
        """Load the prior model onto CUDA in eval mode.

        Args:
            ckpt_path: Prior checkpoint path, relative to ``root_dir``.
            clip_stat_path: Path (relative to ``root_dir``) of a ``torch.load``-able
                file holding a two-element ``(clip_mean, clip_std)`` sequence.
            prior_config: YAML config path, used as given.

        Raises:
            RuntimeError: If ``load_clip`` has not been called yet.
        """
        # Check before doing any file I/O so an ordering mistake fails fast.
        tokenizer = self._require_tokenizer()

        logging.info("Loading prior: %s", ckpt_path)
        config = OmegaConf.load(prior_config)
        # NOTE: torch.load unpickles the file; only point this at trusted stats.
        clip_mean, clip_std = torch.load(
            os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
        )
        prior = self._PRIOR_CLASS.load_from_checkpoint(
            config,
            tokenizer,
            clip_mean,
            clip_std,
            os.path.join(self._root_dir, ckpt_path),
            strict=True,
        )
        prior.cuda()
        prior.eval()
        logging.info("done.")

        self._prior = prior

    def load_decoder(
        self,
        ckpt_path: str,
        decoder_config: str = "configs/decoder_900M_vit_l.yaml",
    ) -> None:
        """Load the decoder model onto CUDA in eval mode.

        Args:
            ckpt_path: Decoder checkpoint path, relative to ``root_dir``.
            decoder_config: YAML config path, used as given.

        Raises:
            RuntimeError: If ``load_clip`` has not been called yet.
        """
        tokenizer = self._require_tokenizer()

        logging.info("Loading decoder: %s", ckpt_path)
        config = OmegaConf.load(decoder_config)
        decoder = self._DECODER_CLASS.load_from_checkpoint(
            config,
            tokenizer,
            os.path.join(self._root_dir, ckpt_path),
            strict=True,
        )
        decoder.cuda()
        decoder.eval()
        logging.info("done.")

        self._decoder = decoder

    def load_sr_64_256(
        self,
        ckpt_path: str,
        sr_config: str = "configs/improved_sr_64_256_1.4B.yaml",
    ) -> None:
        """Load the 64->256 super-resolution model onto CUDA in eval mode.

        Unlike the prior and decoder, this model does not take the tokenizer,
        so it does not depend on ``load_clip`` having been called.

        Args:
            ckpt_path: SR checkpoint path, relative to ``root_dir``.
            sr_config: YAML config path, used as given.
        """
        logging.info("Loading SR(64->256): %s", ckpt_path)
        config = OmegaConf.load(sr_config)
        sr = self._SR256_CLASS.load_from_checkpoint(
            config, os.path.join(self._root_dir, ckpt_path), strict=True
        )
        sr.cuda()
        sr.eval()
        logging.info("done.")

        self._sr_64_256 = sr