diff --git a/ldm/global_opt.py b/ldm/global_opt.py
deleted file mode 100644
index cb8db3d..0000000
--- a/ldm/global_opt.py
+++ /dev/null
@@ -1 +0,0 @@
-# Only Import opt
\ No newline at end of file
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 9b8a6d0..84a94b2 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -6,10 +6,17 @@ from torch.utils.checkpoint import checkpoint
 
 from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
 
 import open_clip
-from ldm import global_opt as g
 from ldm.util import default, count_params, autocast
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
 class AbstractEncoder(nn.Module):
     def __init__(self):
@@ -61,9 +68,9 @@ def disabled_train(self, mode=True):
 
 class FrozenT5Embedder(AbstractEncoder):
     """Uses the T5 transformer encoder for text"""
-    def __init__(self, version="google/t5-v1_1-large", device=g.opt.device, max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+    def __init__(self, version="google/t5-v1_1-large", device=get_device(), max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
         self.tokenizer = T5Tokenizer.from_pretrained(version)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length  # TODO: typical value?
@@ -98,7 +105,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         "hidden"
     ]
 
-    def __init__(self, version="openai/clip-vit-large-patch14", device=g.opt.device, max_length=77,
+    def __init__(self, version="openai/clip-vit-large-patch14", device=get_device(), max_length=77,
                  freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
@@ -142,7 +149,7 @@ class ClipImageEmbedder(nn.Module):
             self,
             model,
             jit=False,
-            device='cuda' if torch.cuda.is_available() else 'cpu',
+            device=get_device(),
             antialias=True,
             ucg_rate=0.
     ):
@@ -185,7 +192,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         "penultimate"
     ]
 
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=g.opt.device, max_length=77,
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
                  freeze=True, layer="last"):
         super().__init__()
         assert layer in self.LAYERS
@@ -243,7 +250,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
     Uses the OpenCLIP vision transformer encoder for images
     """
 
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=get_device(), max_length=77,
                  freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
         super().__init__()
         model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
@@ -298,7 +305,7 @@
 
 
 class FrozenCLIPT5Encoder(AbstractEncoder):
-    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=g.opt.device,
+    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=get_device(),
                  clip_max_length=77, t5_max_length=77):
         super().__init__()
         self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
diff --git a/ldm/modules/karlo/kakao/sampler.py b/ldm/modules/karlo/kakao/sampler.py
index b56bf2f..a827b8b 100644
--- a/ldm/modules/karlo/kakao/sampler.py
+++ b/ldm/modules/karlo/kakao/sampler.py
@@ -11,6 +11,7 @@ import torch
 import torchvision.transforms.functional as TVF
 from torchvision.transforms import InterpolationMode
+from contextlib import nullcontext
 
 from .template import BaseSampler, CKPT_PATH
 
@@ -31,6 +32,7 @@ class T2ISampler(BaseSampler):
     @classmethod
     def from_pretrained(
         cls,
+        device,
         root_dir: str,
         clip_model_path: str,
         clip_stat_path: str,
@@ -41,7 +43,7 @@
             root_dir=root_dir,
             sampling_type=sampling_type,
         )
-        model.load_clip(clip_model_path)
+        model.load_clip(clip_model_path, device)
        model.load_prior(
             f"{CKPT_PATH['prior']}",
             clip_stat_path=clip_stat_path,
@@ -60,10 +62,10 @@
         prompts_batch = [prompt for _ in range(bsz)]
 
         prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
-        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
+        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
 
         decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
-        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
 
         """ Get CLIP text feature """
         clip_model = self._clip
@@ -79,7 +81,7 @@
             tok = torch.cat([tok, cf_token], dim=0)
             mask = torch.cat([mask, cf_mask], dim=0)
 
-        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+        tok, mask = tok.to(device=self.device), mask.to(device=self.device)
         txt_feat, txt_feat_seq = clip_model.encode_text(tok)
 
         return (
@@ -99,7 +101,8 @@
         progressive_mode=None,
     ) -> Iterator[torch.Tensor]:
         assert progressive_mode in ("loop", "stage", "final")
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        precision_scope = nullcontext if self.device.type == 'mps' else torch.autocast
+        with torch.no_grad(), precision_scope(self.device.type):
             (
                 prompts_batch,
                 prior_cf_scales_batch,
@@ -181,6 +184,7 @@ class PriorSampler(BaseSampler):
     @classmethod
     def from_pretrained(
         cls,
+        device,
         root_dir: str,
         clip_model_path: str,
         clip_stat_path: str,
@@ -190,7 +194,7 @@ class PriorSampler(BaseSampler):
             root_dir=root_dir,
             sampling_type=sampling_type,
         )
-        model.load_clip(clip_model_path)
+        model.load_clip(clip_model_path, device)
         model.load_prior(
             f"{CKPT_PATH['prior']}",
             clip_stat_path=clip_stat_path,
@@ -207,10 +211,10 @@ class PriorSampler(BaseSampler):
         prompts_batch = [prompt for _ in range(bsz)]
 
         prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
-        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
+        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
 
         decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
-        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+        decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device=self.device)
 
         """ Get CLIP text feature """
         clip_model = self._clip
@@ -226,7 +230,7 @@ class PriorSampler(BaseSampler):
             tok = torch.cat([tok, cf_token], dim=0)
             mask = torch.cat([mask, cf_mask], dim=0)
 
-        tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+        tok, mask = tok.to(device=self.device), mask.to(device=self.device)
         txt_feat, txt_feat_seq = clip_model.encode_text(tok)
 
         return (
@@ -246,7 +250,8 @@ class PriorSampler(BaseSampler):
         progressive_mode=None,
     ) -> Iterator[torch.Tensor]:
         assert progressive_mode in ("loop", "stage", "final")
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        precision_scope = nullcontext if self.device.type == 'mps' else torch.autocast
+        with torch.no_grad(), precision_scope(self.device.type):
             (
                 prompts_batch,
                 prior_cf_scales_batch,
diff --git a/ldm/modules/karlo/kakao/template.py b/ldm/modules/karlo/kakao/template.py
index 949e80e..24c04ac 100644
--- a/ldm/modules/karlo/kakao/template.py
+++ b/ldm/modules/karlo/kakao/template.py
@@ -73,15 +73,16 @@ class BaseSampler:
 
         return line
 
-    def load_clip(self, clip_path: str):
+    def load_clip(self, clip_path: str, device):
         clip = CustomizedCLIP.load_from_checkpoint(
             os.path.join(self._root_dir, clip_path)
         )
         clip = torch.jit.script(clip)
-        clip.cuda()
+        clip.to(device)
         clip.eval()
 
         self._clip = clip
+        self.device = torch.device(device)
         self._tokenizer = CustomizedTokenizer()
 
     def load_prior(
@@ -105,7 +106,7 @@ class BaseSampler:
             os.path.join(self._root_dir, ckpt_path),
             strict=True,
         )
-        prior.cuda()
+        prior.to(self.device)
         prior.eval()
         logging.info("done.")
 
@@ -121,7 +122,7 @@
             os.path.join(self._root_dir, ckpt_path),
             strict=True,
         )
-        decoder.cuda()
+        decoder.to(self.device)
         decoder.eval()
         logging.info("done.")
 
@@ -134,7 +135,7 @@
         sr = self._SR256_CLASS.load_from_checkpoint(
             config, os.path.join(self._root_dir, ckpt_path), strict=True
         )
-        sr.cuda()
+        sr.to(self.device)
         sr.eval()
         logging.info("done.")
 
diff --git a/scripts/gradio/depth2img.py b/scripts/gradio/depth2img.py
index c791a4d..ea563cb 100644
--- a/scripts/gradio/depth2img.py
+++ b/scripts/gradio/depth2img.py
@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -16,15 +17,23 @@ from ldm.data.util import AddMiDaS
 
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device=device)
     return sampler
 
@@ -54,8 +63,7 @@ def make_batch_sd(
 
 
 def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None, do_full_sample=False):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
 
@@ -64,8 +72,9 @@
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(
             image, txt=prompt, device=device, num_samples=num_samples)
         z = model.get_first_stage_encoding(model.encode_first_stage(
diff --git a/scripts/gradio/inpainting.py b/scripts/gradio/inpainting.py
index 09d44f3..99aba1d 100644
--- a/scripts/gradio/inpainting.py
+++ b/scripts/gradio/inpainting.py
@@ -6,6 +6,7 @@ import gradio as gr
 from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 from pathlib import Path
 
@@ -16,6 +17,16 @@ from ldm.util import instantiate_from_config
 
 torch.set_grad_enabled(False)
 
+
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def put_watermark(img, wm_encoder=None):
     if wm_encoder is not None:
         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@@ -30,10 +41,9 @@ def initialize_model(config, ckpt):
 
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device=device)
 
     return sampler
 
@@ -67,8 +77,7 @@ def make_batch_sd(
 
 
 def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
 
     print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@@ -81,8 +90,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
     start_code = prng.randn(num_samples, 4, h // 8, w // 8)
     start_code = torch.from_numpy(start_code).to(
         device=device, dtype=torch.float32)
 
-    with torch.no_grad(), \
-            torch.autocast("cuda"):
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
+    with torch.no_grad(),\
+            precision_scope(device.type):
         batch = make_batch_sd(image, mask, txt=prompt,
                               device=device, num_samples=num_samples)
 
diff --git a/scripts/gradio/superresolution.py b/scripts/gradio/superresolution.py
index 3d08fbf..fe50627 100644
--- a/scripts/gradio/superresolution.py
+++ b/scripts/gradio/superresolution.py
@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -17,15 +18,23 @@ from ldm.util import exists, instantiate_from_config
 
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device=device)
     return sampler
 
@@ -54,8 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
 
 
 def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
-    device = torch.device(
-        "cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
     prng = np.random.RandomState(seed)
@@ -67,8 +75,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
     wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(
             image, txt=prompt, device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])
diff --git a/scripts/img2img.py b/scripts/img2img.py
index dcf9018..2ff3c7c 100644
--- a/scripts/img2img.py
+++ b/scripts/img2img.py
@@ -16,16 +16,15 @@ from pytorch_lightning import seed_everything
 
 from imwatermark import WatermarkEncoder
 
-from ldm import global_opt as g
 from scripts.txt2img import put_watermark
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 
 
 def get_device():
-    if(torch.cuda.is_available()):
+    if torch.cuda.is_available():
         return 'cuda'
-    elif(torch.backends.mps.is_available()):
+    elif torch.backends.mps.is_available():
         return 'mps'
     else:
         return 'cpu'
@@ -197,7 +196,7 @@
         default="autocast"
     )
 
-    opt = g.opt = parser.parse_args()
+    opt = parser.parse_args()
     seed_everything(opt.seed)
 
     config = OmegaConf.load(f"{opt.config}")
diff --git a/scripts/streamlit/depth2img.py b/scripts/streamlit/depth2img.py
index 7f80223..b21033d 100644
--- a/scripts/streamlit/depth2img.py
+++ b/scripts/streamlit/depth2img.py
@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -16,15 +17,25 @@ from ldm.data.util import AddMiDaS
 
 torch.set_grad_enabled(False)
 
+
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 @st.cache(allow_output_mutation=True)
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device=device)
     return sampler
 
@@ -52,7 +63,7 @@ def make_batch_sd(
 
 
 def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None, do_full_sample=False):
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
 
@@ -61,8 +72,9 @@ def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=No
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
 
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
         z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key]))  # move to latent space
         c = model.cond_stage_model.encode(batch["txt"])
diff --git a/scripts/streamlit/inpainting.py b/scripts/streamlit/inpainting.py
index c35772f..59add92 100644
--- a/scripts/streamlit/inpainting.py
+++ b/scripts/streamlit/inpainting.py
@@ -7,6 +7,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat
 from streamlit_drawable_canvas import st_canvas
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from ldm.models.diffusion.ddim import DDIMSampler
@@ -16,6 +17,15 @@ from ldm.util import instantiate_from_config
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def put_watermark(img, wm_encoder=None):
     if wm_encoder is not None:
         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@@ -31,9 +41,9 @@ def initialize_model(config, ckpt):
 
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device=device)
 
     return sampler
 
@@ -67,7 +77,7 @@ def make_batch_sd(
 
 def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.):
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
 
     print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
@@ -79,8 +89,9 @@ def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1
     start_code = prng.randn(num_samples, 4, h // 8, w // 8)
     start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
 
-    with torch.no_grad(), \
-            torch.autocast("cuda"):
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
+    with torch.no_grad(),\
+            precision_scope(device.type):
         batch = make_batch_sd(image, mask, txt=prompt,
                               device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])
diff --git a/scripts/streamlit/stableunclip.py b/scripts/streamlit/stableunclip.py
index 122fa9a..94a2738 100644
--- a/scripts/streamlit/stableunclip.py
+++ b/scripts/streamlit/stableunclip.py
@@ -30,6 +30,15 @@ VERSION2SPECS = {
 }
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 def get_obj_from_str(string, reload=False):
     module, cls = string.rsplit(".", 1)
     importlib.invalidate_caches()
@@ -69,7 +78,7 @@ def load_img(display=True, key=None):
 
 
 def get_init_img(batch_size=1, key=None):
-    init_image = load_img(key=key).cuda()
+    init_image = load_img(key=key).to(get_device())
     init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
     return init_image
 
@@ -97,7 +106,10 @@ def sample(
         only_adm_cond=False
 ):
     batch_size = n_samples
+    device = torch.device(get_device())
     precision_scope = autocast if not use_full_precision else nullcontext
+    if device.type == 'mps':
+        precision_scope = nullcontext
     # decoderscope = autocast if not use_full_precision else nullcontext
     if use_full_precision: st.warning(f"Running {model.__class__.__name__} at full precision.")
     if isinstance(prompt, str):
@@ -106,7 +118,7 @@
 
     outputs = st.empty()
 
-    with precision_scope("cuda"):
+    with precision_scope(device.type):
         with model.ema_scope():
             all_samples = list()
             for n in trange(n_runs, desc="Sampling"):
@@ -208,6 +220,7 @@
             from ldm.modules.karlo.kakao.sampler import T2ISampler
             st.info("Loading full KARLO..")
             karlo = T2ISampler.from_pretrained(
+                device=get_device(),
                 root_dir="checkpoints/karlo_models",
                 clip_model_path="ViT-L-14.pt",
                 clip_stat_path="ViT-L-14_stats.th",
@@ -227,6 +240,7 @@
             from ldm.modules.karlo.kakao.sampler import PriorSampler
             st.info("Loading KARLO CLIP prior...")
             karlo_prior = PriorSampler.from_pretrained(
+                device=get_device(),
                 root_dir="checkpoints/karlo_models",
                 clip_model_path="ViT-L-14.pt",
                 clip_stat_path="ViT-L-14_stats.th",
@@ -263,7 +277,7 @@
         print("unexpected keys:")
         print(u)
 
-    model.cuda()
+    model.to(get_device())
     model.eval()
     print(f"Loaded global step {global_step}")
     return model, msg
@@ -301,7 +315,7 @@
         pass
     else:
         if sampler == "DPM":
-            sampler = DPMSolverSampler(state["model"])
+            sampler = DPMSolverSampler(state["model"], device=get_device())
         elif sampler == "DDIM":
-            sampler = DDIMSampler(state["model"])
+            sampler = DDIMSampler(state["model"], device=get_device())
         else:
diff --git a/scripts/streamlit/superresolution.py b/scripts/streamlit/superresolution.py
index c1172b0..eba5ce4 100644
--- a/scripts/streamlit/superresolution.py
+++ b/scripts/streamlit/superresolution.py
@@ -6,6 +6,7 @@ from PIL import Image
 from omegaconf import OmegaConf
 from einops import repeat, rearrange
 from pytorch_lightning import seed_everything
+from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
 from scripts.txt2img import put_watermark
@@ -17,15 +18,24 @@ from ldm.util import exists, instantiate_from_config
 
 torch.set_grad_enabled(False)
 
 
+def get_device():
+    if torch.cuda.is_available():
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
+
+
 @st.cache(allow_output_mutation=True)
 def initialize_model(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
     model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = model.to(device)
-    sampler = DDIMSampler(model)
+    sampler = DDIMSampler(model, device=device)
     return sampler
 
@@ -53,7 +63,7 @@ def make_noise_augmentation(model, batch, noise_level=None):
 
 def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = torch.device(get_device())
     model = sampler.model
     seed_everything(seed)
     prng = np.random.RandomState(seed)
@@ -64,8 +74,9 @@ def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callb
     wm = "SDV2"
     wm_encoder = WatermarkEncoder()
     wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
     with torch.no_grad(),\
-            torch.autocast("cuda"):
+            precision_scope(device.type):
         batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
         c = model.cond_stage_model.encode(batch["txt"])
         c_cat = list()
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index f3297dd..132df4b 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -13,7 +13,6 @@ from torch import autocast
 from contextlib import nullcontext
 from imwatermark import WatermarkEncoder
 
-from ldm import global_opt as g
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -22,9 +21,9 @@ from ldm.models.diffusion.dpm_solver import DPMSolverSampler
 torch.set_grad_enabled(False)
 
 def get_device():
-    if(torch.cuda.is_available()):
+    if torch.cuda.is_available():
         return 'cuda'
-    elif(torch.backends.mps.is_available()):
+    elif torch.backends.mps.is_available():
         return 'mps'
     else:
         return 'cpu'
@@ -389,5 +388,5 @@
 
 
 if __name__ == "__main__":
-    opt = g.opt = parse_args()
+    opt = parse_args()
     main(opt)
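
Reviewer note: every touched script follows the same device-selection and precision-scope pattern. The snippet below is a minimal, self-contained sketch of that pattern for reference only; it is not part of the patch, and run_inference is a hypothetical helper invented for illustration.

from contextlib import nullcontext

import torch


def get_device():
    # Prefer CUDA, fall back to Apple's MPS backend, then plain CPU.
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    else:
        return 'cpu'


def run_inference(model, x):
    # Hypothetical helper for illustration; not code from this patch.
    device = torch.device(get_device())
    model, x = model.to(device), x.to(device)
    # torch.autocast had no MPS backend when this pattern was written, so MPS
    # falls back to nullcontext, which accepts and ignores the device string.
    precision_scope = nullcontext if device.type == 'mps' else torch.autocast
    with torch.no_grad(), precision_scope(device.type):
        return model(x)

On CUDA this preserves the previous float16 autocast behavior; on CPU, torch.autocast defaults to bfloat16, which may change numerics slightly relative to the old CUDA-only path.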