mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-23 08:04:59 +00:00
141 lines
No EOL
4.2 KiB
Python
141 lines
No EOL
4.2 KiB
Python
# ------------------------------------------------------------------------------------
|
|
# Karlo-v1.0.alpha
|
|
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
|
|
# ------------------------------------------------------------------------------------
|
|
|
|
import os
|
|
import logging
|
|
import torch
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
|
|
from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
|
|
from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
|
|
from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
|
|
|
|
|
|
SAMPLING_CONF = {
|
|
"default": {
|
|
"prior_sm": "25",
|
|
"prior_n_samples": 1,
|
|
"prior_cf_scale": 4.0,
|
|
"decoder_sm": "50",
|
|
"decoder_cf_scale": 8.0,
|
|
"sr_sm": "7",
|
|
},
|
|
"fast": {
|
|
"prior_sm": "25",
|
|
"prior_n_samples": 1,
|
|
"prior_cf_scale": 4.0,
|
|
"decoder_sm": "25",
|
|
"decoder_cf_scale": 8.0,
|
|
"sr_sm": "7",
|
|
},
|
|
}
|
|
|
|
CKPT_PATH = {
|
|
"prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
|
|
"decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
|
|
"sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
|
|
}
|
|
|
|
|
|
class BaseSampler:
|
|
_PRIOR_CLASS = PriorDiffusionModel
|
|
_DECODER_CLASS = Text2ImProgressiveModel
|
|
_SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
|
|
|
|
def __init__(
|
|
self,
|
|
root_dir: str,
|
|
sampling_type: str = "fast",
|
|
):
|
|
self._root_dir = root_dir
|
|
|
|
sampling_type = SAMPLING_CONF[sampling_type]
|
|
self._prior_sm = sampling_type["prior_sm"]
|
|
self._prior_n_samples = sampling_type["prior_n_samples"]
|
|
self._prior_cf_scale = sampling_type["prior_cf_scale"]
|
|
|
|
assert self._prior_n_samples == 1
|
|
|
|
self._decoder_sm = sampling_type["decoder_sm"]
|
|
self._decoder_cf_scale = sampling_type["decoder_cf_scale"]
|
|
|
|
self._sr_sm = sampling_type["sr_sm"]
|
|
|
|
def __repr__(self):
|
|
line = ""
|
|
line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
|
|
line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
|
|
line += f"SR(64->256), sampling method: {self._sr_sm}"
|
|
|
|
return line
|
|
|
|
def load_clip(self, clip_path: str):
|
|
clip = CustomizedCLIP.load_from_checkpoint(
|
|
os.path.join(self._root_dir, clip_path)
|
|
)
|
|
clip = torch.jit.script(clip)
|
|
clip.cuda()
|
|
clip.eval()
|
|
|
|
self._clip = clip
|
|
self._tokenizer = CustomizedTokenizer()
|
|
|
|
def load_prior(
|
|
self,
|
|
ckpt_path: str,
|
|
clip_stat_path: str,
|
|
prior_config: str = "configs/prior_1B_vit_l.yaml"
|
|
):
|
|
logging.info(f"Loading prior: {ckpt_path}")
|
|
|
|
config = OmegaConf.load(prior_config)
|
|
clip_mean, clip_std = torch.load(
|
|
os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
|
|
)
|
|
|
|
prior = self._PRIOR_CLASS.load_from_checkpoint(
|
|
config,
|
|
self._tokenizer,
|
|
clip_mean,
|
|
clip_std,
|
|
os.path.join(self._root_dir, ckpt_path),
|
|
strict=True,
|
|
)
|
|
prior.cuda()
|
|
prior.eval()
|
|
logging.info("done.")
|
|
|
|
self._prior = prior
|
|
|
|
def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"):
|
|
logging.info(f"Loading decoder: {ckpt_path}")
|
|
|
|
config = OmegaConf.load(decoder_config)
|
|
decoder = self._DECODER_CLASS.load_from_checkpoint(
|
|
config,
|
|
self._tokenizer,
|
|
os.path.join(self._root_dir, ckpt_path),
|
|
strict=True,
|
|
)
|
|
decoder.cuda()
|
|
decoder.eval()
|
|
logging.info("done.")
|
|
|
|
self._decoder = decoder
|
|
|
|
def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"):
|
|
logging.info(f"Loading SR(64->256): {ckpt_path}")
|
|
|
|
config = OmegaConf.load(sr_config)
|
|
sr = self._SR256_CLASS.load_from_checkpoint(
|
|
config, os.path.join(self._root_dir, ckpt_path), strict=True
|
|
)
|
|
sr.cuda()
|
|
sr.eval()
|
|
logging.info("done.")
|
|
|
|
self._sr_64_256 = sr |