Support gradio and streamlit

This commit is contained in:
Ftps 2023-04-02 21:00:29 +09:00
parent d6933311e7
commit 90d4c71350
13 changed files with 154 additions and 68 deletions

View file

@ -1 +0,0 @@
# Only Import opt

View file

@ -6,10 +6,17 @@ from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
import open_clip
from ldm import global_opt as g
from ldm.util import default, count_params, autocast
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
class AbstractEncoder(nn.Module):
def __init__(self):
@ -61,9 +68,9 @@ def disabled_train(self, mode=True):
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device=g.opt.device, max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
def __init__(self, version="google/t5-v1_1-large", device=get_device(), max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.tokenizer = T5Tokenizer.from_pretrained(version, device)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
@ -98,7 +105,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device=g.opt.device, max_length=77,
def __init__(self, version="openai/clip-vit-large-patch14", device=get_device(), max_length=77,
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
@ -142,7 +149,7 @@ class ClipImageEmbedder(nn.Module):
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
device=get_device(),
antialias=True,
ucg_rate=0.
):
@ -185,7 +192,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
"penultimate"
]
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=g.opt.device, max_length=77,
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
freeze=True, layer="last"):
super().__init__()
assert layer in self.LAYERS
@ -243,7 +250,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
@ -298,7 +305,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=g.opt.device,
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=get_device(),
clip_max_length=77, t5_max_length=77):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)

View file

@ -11,6 +11,7 @@ import torch
import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode
from contextlib import nullcontext
from .template import BaseSampler, CKPT_PATH
@ -31,6 +32,7 @@ class T2ISampler(BaseSampler):
@classmethod
def from_pretrained(
cls,
device,
root_dir: str,
clip_model_path: str,
clip_stat_path: str,
@ -41,7 +43,7 @@ class T2ISampler(BaseSampler):
root_dir=root_dir,
sampling_type=sampling_type,
)
model.load_clip(clip_model_path)
model.load_clip(clip_model_path, device)
model.load_prior(
f"{CKPT_PATH['prior']}",
clip_stat_path=clip_stat_path,
@ -60,10 +62,10 @@ class T2ISampler(BaseSampler):
prompts_batch = [prompt for _ in range(bsz)]
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
""" Get CLIP text feature """
clip_model = self._clip
@ -79,7 +81,7 @@ class T2ISampler(BaseSampler):
tok = torch.cat([tok, cf_token], dim=0)
mask = torch.cat([mask, cf_mask], dim=0)
tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
tok, mask = tok.to(device=self.device), mask.to(device=self.device)
txt_feat, txt_feat_seq = clip_model.encode_text(tok)
return (
@ -99,7 +101,8 @@ class T2ISampler(BaseSampler):
progressive_mode=None,
) -> Iterator[torch.Tensor]:
assert progressive_mode in ("loop", "stage", "final")
with torch.no_grad(), torch.cuda.amp.autocast():
precision_scope = nullcontext if self.device.type == 'mps' else torch.cuda.amp.autocast
with torch.no_grad(), precision_scope(self.device.type):
(
prompts_batch,
prior_cf_scales_batch,
@ -181,6 +184,7 @@ class PriorSampler(BaseSampler):
@classmethod
def from_pretrained(
cls,
device,
root_dir: str,
clip_model_path: str,
clip_stat_path: str,
@ -190,7 +194,7 @@ class PriorSampler(BaseSampler):
root_dir=root_dir,
sampling_type=sampling_type,
)
model.load_clip(clip_model_path)
model.load_clip(clip_model_path, device)
model.load_prior(
f"{CKPT_PATH['prior']}",
clip_stat_path=clip_stat_path,
@ -207,10 +211,10 @@ class PriorSampler(BaseSampler):
prompts_batch = [prompt for _ in range(bsz)]
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
""" Get CLIP text feature """
clip_model = self._clip
@ -226,7 +230,7 @@ class PriorSampler(BaseSampler):
tok = torch.cat([tok, cf_token], dim=0)
mask = torch.cat([mask, cf_mask], dim=0)
tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
tok, mask = tok.to(device=self.device), mask.to(device=self.device)
txt_feat, txt_feat_seq = clip_model.encode_text(tok)
return (
@ -246,7 +250,8 @@ class PriorSampler(BaseSampler):
progressive_mode=None,
) -> Iterator[torch.Tensor]:
assert progressive_mode in ("loop", "stage", "final")
with torch.no_grad(), torch.cuda.amp.autocast():
precision_scope = nullcontext if self.device.type == 'mps' else torch.autocast
with torch.no_grad(), precision_scope(self.device.type):
(
prompts_batch,
prior_cf_scales_batch,

View file

@ -73,15 +73,16 @@ class BaseSampler:
return line
def load_clip(self, clip_path: str):
def load_clip(self, clip_path: str, device):
clip = CustomizedCLIP.load_from_checkpoint(
os.path.join(self._root_dir, clip_path)
)
clip = torch.jit.script(clip)
clip.cuda()
clip.to(device)
clip.eval()
self._clip = clip
self.device = device
self._tokenizer = CustomizedTokenizer()
def load_prior(
@ -105,7 +106,7 @@ class BaseSampler:
os.path.join(self._root_dir, ckpt_path),
strict=True,
)
prior.cuda()
prior.to(self.device)
prior.eval()
logging.info("done.")
@ -121,7 +122,7 @@ class BaseSampler:
os.path.join(self._root_dir, ckpt_path),
strict=True,
)
decoder.cuda()
decoder.to(self.device)
decoder.eval()
logging.info("done.")
@ -134,7 +135,7 @@ class BaseSampler:
sr = self._SR256_CLASS.load_from_checkpoint(
config, os.path.join(self._root_dir, ckpt_path), strict=True
)
sr.cuda()
sr.to(self.device)
sr.eval()
logging.info("done.")

View file

@ -6,6 +6,7 @@ from PIL import Image
from omegaconf import OmegaConf
from einops import repeat, rearrange
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark
@ -16,15 +17,23 @@ from ldm.data.util import AddMiDaS
torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def initialize_model(config, ckpt):
config = OmegaConf.load(config)
model = instantiate_from_config(config.model)
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device(
"cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = model.to(device)
sampler = DDIMSampler(model)
sampler = DDIMSampler(model, device)
return sampler
@ -54,8 +63,7 @@ def make_batch_sd(
def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
do_full_sample=False):
device = torch.device(
"cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = sampler.model
seed_everything(seed)
@ -64,8 +72,9 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\
torch.autocast("cuda"):
precision_scope(device.type):
batch = make_batch_sd(
image, txt=prompt, device=device, num_samples=num_samples)
z = model.get_first_stage_encoding(model.encode_first_stage(

View file

@ -6,6 +6,7 @@ import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
from einops import repeat
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
from pathlib import Path
@ -16,6 +17,16 @@ from ldm.util import instantiate_from_config
torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def put_watermark(img, wm_encoder=None):
if wm_encoder is not None:
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@ -30,10 +41,9 @@ def initialize_model(config, ckpt):
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device(
"cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = model.to(device)
sampler = DDIMSampler(model)
sampler = DDIMSampler(model, device)
return sampler
@ -67,8 +77,7 @@ def make_batch_sd(
def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
device = torch.device(
"cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = sampler.model
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@ -81,8 +90,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
start_code = torch.from_numpy(start_code).to(
device=device, dtype=torch.float32)
with torch.no_grad(), \
torch.autocast("cuda"):
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\
precision_scope(device.type):
batch = make_batch_sd(image, mask, txt=prompt,
device=device, num_samples=num_samples)

View file

@ -6,6 +6,7 @@ from PIL import Image
from omegaconf import OmegaConf
from einops import repeat, rearrange
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark
@ -17,15 +18,23 @@ from ldm.util import exists, instantiate_from_config
torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def initialize_model(config, ckpt):
config = OmegaConf.load(config)
model = instantiate_from_config(config.model)
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device(
"cuda") if torch.cuda.is_available() else torch.device("cpu")
device = get_device()
model = model.to(device)
sampler = DDIMSampler(model)
sampler = DDIMSampler(model, device)
return sampler
@ -54,8 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
device = torch.device(
"cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = sampler.model
seed_everything(seed)
prng = np.random.RandomState(seed)
@ -67,8 +75,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
wm = "SDV2"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\
torch.autocast("cuda"):
precision_scope(device.type):
batch = make_batch_sd(
image, txt=prompt, device=device, num_samples=num_samples)
c = model.cond_stage_model.encode(batch["txt"])

View file

@ -16,16 +16,15 @@ from pytorch_lightning import seed_everything
from imwatermark import WatermarkEncoder
from ldm import global_opt as g
from scripts.txt2img import put_watermark
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
def get_device():
if(torch.cuda.is_available()):
if torch.cuda.is_available():
return 'cuda'
elif(torch.backends.mps.is_available()):
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
@ -197,7 +196,7 @@ def main():
default="autocast"
)
opt = g.opt = parser.parse_args()
opt = parser.parse_args()
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")

View file

@ -6,6 +6,7 @@ from PIL import Image
from omegaconf import OmegaConf
from einops import repeat, rearrange
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark
@ -16,15 +17,25 @@ from ldm.data.util import AddMiDaS
torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
@st.cache(allow_output_mutation=True)
def initialize_model(config, ckpt):
config = OmegaConf.load(config)
model = instantiate_from_config(config.model)
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = model.to(device)
sampler = DDIMSampler(model)
sampler = DDIMSampler(model, device)
return sampler
@ -52,7 +63,7 @@ def make_batch_sd(
def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
do_full_sample=False):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = sampler.model
seed_everything(seed)
@ -61,8 +72,9 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\
torch.autocast("cuda"):
precision_scope(device.type):
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space
c = model.cond_stage_model.encode(batch["txt"])

View file

@ -7,6 +7,7 @@ from PIL import Image
from omegaconf import OmegaConf
from einops import repeat
from streamlit_drawable_canvas import st_canvas
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
from ldm.models.diffusion.ddim import DDIMSampler
@ -16,6 +17,15 @@ from ldm.util import instantiate_from_config
torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def put_watermark(img, wm_encoder=None):
if wm_encoder is not None:
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@ -31,9 +41,9 @@ def initialize_model(config, ckpt):
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = model.to(device)
sampler = DDIMSampler(model)
sampler = DDIMSampler(model, device)
return sampler
@ -67,7 +77,7 @@ def make_batch_sd(
def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = get_device()
model = sampler.model
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@ -79,8 +89,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
start_code = prng.randn(num_samples, 4, h // 8, w // 8)
start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
with torch.no_grad(), \
torch.autocast("cuda"):
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\
precision_scope(device.type):
batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
c = model.cond_stage_model.encode(batch["txt"])

View file

@ -30,6 +30,15 @@ VERSION2SPECS = {
}
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
importlib.invalidate_caches()
@ -69,7 +78,7 @@ def load_img(display=True, key=None):
def get_init_img(batch_size=1, key=None):
init_image = load_img(key=key).cuda()
init_image = load_img(key=key).to(get_device())
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
return init_image
@ -97,7 +106,10 @@ def sample(
only_adm_cond=False
):
batch_size = n_samples
device = torch.device(get_device())
precision_scope = autocast if not use_full_precision else nullcontext
if device.type == 'mps':
precision_scope = nullcontext
# decoderscope = autocast if not use_full_precision else nullcontext
if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
if isinstance(prompt, str):
@ -106,7 +118,7 @@ def sample(
outputs = st.empty()
with precision_scope("cuda"):
with precision_scope(device.type):
with model.ema_scope():
all_samples = list()
for n in trange(n_runs, desc="Sampling"):
@ -208,6 +220,7 @@ def init(version="Stable unCLIP-L", load_karlo_prior=False):
from ldm.modules.karlo.kakao.sampler import T2ISampler
st.info("Loading full KARLO..")
karlo = T2ISampler.from_pretrained(
device=get_device(),
root_dir="checkpoints/karlo_models",
clip_model_path="ViT-L-14.pt",
clip_stat_path="ViT-L-14_stats.th",
@ -227,6 +240,7 @@ def init(version="Stable unCLIP-L", load_karlo_prior=False):
from ldm.modules.karlo.kakao.sampler import PriorSampler
st.info("Loading KARLO CLIP prior...")
karlo_prior = PriorSampler.from_pretrained(
device=get_device(),
root_dir="checkpoints/karlo_models",
clip_model_path="ViT-L-14.pt",
clip_stat_path="ViT-L-14_stats.th",
@ -263,7 +277,7 @@ def load_model_from_config(config, ckpt, verbose=False, vae_sd=None):
print("unexpected keys:")
print(u)
model.cuda()
model.to(get_device())
model.eval()
print(f"Loaded global step {global_step}")
return model, msg
@ -301,7 +315,7 @@ if __name__ == "__main__":
pass
else:
if sampler == "DPM":
sampler = DPMSolverSampler(state["model"])
sampler = DPMSolverSampler(state["model"], device=get_device())
elif sampler == "DDIM":
sampler = DDIMSampler(state["model"])
else:

View file

@ -6,6 +6,7 @@ from PIL import Image
from omegaconf import OmegaConf
from einops import repeat, rearrange
from pytorch_lightning import seed_everything
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
from scripts.txt2img import put_watermark
@ -17,15 +18,24 @@ from ldm.util import exists, instantiate_from_config
torch.set_grad_enabled(False)
def get_device():
if torch.cuda.is_available():
return 'cuda'
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
@st.cache(allow_output_mutation=True)
def initialize_model(config, ckpt):
config = OmegaConf.load(config)
model = instantiate_from_config(config.model)
model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = model.to(device)
sampler = DDIMSampler(model)
sampler = DDIMSampler(model, device)
return sampler
@ -53,7 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(get_device())
model = sampler.model
seed_everything(seed)
prng = np.random.RandomState(seed)
@ -64,8 +74,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
wm = "SDV2"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
precision_scope = nullcontext if device.type == 'mps' else torch.autocast
with torch.no_grad(),\
torch.autocast("cuda"):
precision_scope(device.type):
batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
c = model.cond_stage_model.encode(batch["txt"])
c_cat = list()

View file

@ -13,7 +13,6 @@ from torch import autocast
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
from ldm import global_opt as g
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
@ -22,9 +21,9 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
torch.set_grad_enabled(False)
def get_device():
if(torch.cuda.is_available()):
if torch.cuda.is_available():
return 'cuda'
elif(torch.backends.mps.is_available()):
elif torch.backends.mps.is_available():
return 'mps'
else:
return 'cpu'
@ -389,5 +388,5 @@ def main(opt):
if __name__ == "__main__":
opt = g.opt = parse_args()
opt = parse_args()
main(opt)