diff --git a/README.md b/README.md index f413068..0cc7471 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,41 @@ Note: The inference config for all model versions is designed to be used with EM For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from non-EMA to EMA weights. +#### Enable Intel® Extension for PyTorch* optimizations in Text-to-Image script + +If you're planning on running Text-to-Image on an Intel® CPU, try to sample an image with TorchScript and Intel® Extension for PyTorch* optimizations. Intel® Extension for PyTorch* extends PyTorch by enabling up-to-date feature optimizations for an extra performance boost on Intel® hardware. It can optimize the memory layout of the operators to the Channels Last memory format, which is generally beneficial for Intel CPUs, take advantage of the most advanced instruction set available on a machine, optimize operators, and more. + +**Prerequisites** + +Before running the script, make sure you have all the needed libraries installed (the optimization was checked on `Ubuntu 20.04`). Install [jemalloc](https://github.com/jemalloc/jemalloc), [numactl](https://linux.die.net/man/8/numactl), Intel® OpenMP and Intel® Extension for PyTorch*. + +```bash +apt-get install numactl libjemalloc-dev +pip install intel-openmp +pip install intel_extension_for_pytorch -f https://software.intel.com/ipex-whl-stable +``` + +To sample from the _SD2.1-v_ model with TorchScript+IPEX optimizations, run the following. Remember to specify the desired number of instances you want to run the program on ([more](https://github.com/intel/intel-extension-for-pytorch/blob/master/intel_extension_for_pytorch/cpu/launch.py#L48)). 
+ +``` +MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt <path/to/768model.ckpt> --config configs/stable-diffusion/intel/v2-inference-v-fp32.yaml --H 768 --W 768 --precision full --device cpu --torchscript --ipex +``` + +To sample from the base model with IPEX optimizations, use: + +``` +MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt <path/to/model.ckpt> --config configs/stable-diffusion/intel/v2-inference-fp32.yaml --n_samples 1 --n_iter 4 --precision full --device cpu --torchscript --ipex +``` + +If you're using a CPU that supports `bfloat16`, consider sampling from the model with bfloat16 enabled for a performance boost, like so: + +```bash +# SD2.1-v +MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt <path/to/768model.ckpt> --config configs/stable-diffusion/intel/v2-inference-v-bf16.yaml --H 768 --W 768 --precision full --device cpu --torchscript --ipex --bf16 +# SD2.1-base +MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt <path/to/model.ckpt> --config configs/stable-diffusion/intel/v2-inference-bf16.yaml --precision full --device cpu --torchscript --ipex --bf16 +``` + ### Image Modification with Stable Diffusion 
![depth2img-stable2](assets/stable-samples/depth2img/merged-0000.png) diff --git a/configs/stable-diffusion/intel/v2-inference-bf16.yaml b/configs/stable-diffusion/intel/v2-inference-bf16.yaml new file mode 100644 index 0000000..66f0dbd --- /dev/null +++ b/configs/stable-diffusion/intel/v2-inference-bf16.yaml @@ -0,0 +1,71 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: MIT + +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: False + use_fp16: False + use_bf16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/configs/stable-diffusion/intel/v2-inference-fp32.yaml 
b/configs/stable-diffusion/intel/v2-inference-fp32.yaml new file mode 100644 index 0000000..7b66ac8 --- /dev/null +++ b/configs/stable-diffusion/intel/v2-inference-fp32.yaml @@ -0,0 +1,70 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: MIT + +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: False + use_fp16: False + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/configs/stable-diffusion/intel/v2-inference-v-bf16.yaml b/configs/stable-diffusion/intel/v2-inference-v-bf16.yaml new file mode 100644 index 0000000..2b4b0e6 --- /dev/null +++ b/configs/stable-diffusion/intel/v2-inference-v-bf16.yaml @@ 
-0,0 +1,72 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: MIT + +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: False + use_fp16: False + use_bf16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/configs/stable-diffusion/intel/v2-inference-v-fp32.yaml b/configs/stable-diffusion/intel/v2-inference-v-fp32.yaml new file mode 100644 index 0000000..8ccd92e --- /dev/null +++ b/configs/stable-diffusion/intel/v2-inference-v-fp32.yaml @@ -0,0 +1,71 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: MIT + +model: + base_learning_rate: 1.0e-4 + target: 
ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: False + use_fp16: False + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 27ead0e..c6cfd57 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -8,16 +8,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps 
self.schedule = schedule + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != self.device: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 6090212..6d2f5a7 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -1326,7 +1326,13 @@ class DiffusionWrapper(pl.LightningModule): cc = torch.cat(c_crossattn, 1) else: cc = c_crossattn - out = self.diffusion_model(x, t, context=cc) + if hasattr(self, "scripted_diffusion_model"): + # TorchScript changes names of the arguments + # with argument cc defined as context=cc scripted model will produce + # an error: RuntimeError: forward() is missing value for argument 'argument_3'. + out = self.scripted_diffusion_model(x, t, cc) + else: + out = self.diffusion_model(x, t, context=cc) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8..4270c61 100644 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -11,16 +11,17 @@ MODEL_TYPES = { class DPMSolverSampler(object): - def __init__(self, model, **kwargs): + def __init__(self, model, device=torch.device("cuda"), **kwargs): super().__init__() self.model = model + self.device = device to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != self.device: + attr = 
attr.to(self.device) setattr(self, name, attr) @torch.no_grad() diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 7002a36..9d31b39 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -10,16 +10,17 @@ from ldm.models.diffusion.sampling_util import norm_thresholding class PLMSSampler(object): - def __init__(self, model, schedule="linear", **kwargs): + def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule + self.device = device def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) + if attr.device != self.device: + attr = attr.to(self.device) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py index 7df6b5a..764a34b 100644 --- a/ldm/modules/diffusionmodules/openaimodel.py +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -454,6 +454,7 @@ class UNetModel(nn.Module): num_classes=None, use_checkpoint=False, use_fp16=False, + use_bf16=False, num_heads=-1, num_head_channels=-1, num_heads_upsample=-1, @@ -518,6 +519,7 @@ class UNetModel(nn.Module): self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.dtype = th.float16 if use_fp16 else th.float32 + self.dtype = th.bfloat16 if use_bf16 else self.dtype self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 1ed42a3..9d955e3 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -25,7 +25,7 @@ def chunk(it, size): return iter(lambda: tuple(islice(it, size)), ()) -def load_model_from_config(config, ckpt, verbose=False): +def 
load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False): print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: @@ -40,7 +40,13 @@ def load_model_from_config(config, ckpt, verbose=False): print("unexpected keys:") print(u) - model.cuda() + if device == torch.device("cuda"): + model.cuda() + elif device == torch.device("cpu"): + model.cpu() + model.cond_stage_model.device = "cpu" + else: + raise ValueError(f"Incorrect device name. Received: {device}") model.eval() return model @@ -171,6 +177,28 @@ def parse_args(): default=1, help="repeat each prompt in file this often", ) + parser.add_argument( + "--device", + type=str, + help="Device on which Stable Diffusion will be run", + choices=["cpu", "cuda"], + default="cpu" + ) + parser.add_argument( + "--torchscript", + action='store_true', + help="Use TorchScript", + ) + parser.add_argument( + "--ipex", + action='store_true', + help="Use Intel® Extension for PyTorch*", + ) + parser.add_argument( + "--bf16", + action='store_true', + help="Use bfloat16", + ) opt = parser.parse_args() return opt @@ -187,17 +215,15 @@ def main(opt): seed_everything(opt.seed) config = OmegaConf.load(f"{opt.config}") - model = load_model_from_config(config, f"{opt.ckpt}") - - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - model = model.to(device) + device = torch.device("cuda") if opt.device == "cuda" else torch.device("cpu") + model = load_model_from_config(config, f"{opt.ckpt}", device) if opt.plms: - sampler = PLMSSampler(model) + sampler = PLMSSampler(model, device=device) elif opt.dpm: - sampler = DPMSolverSampler(model) + sampler = DPMSolverSampler(model, device=device) else: - sampler = DDIMSampler(model) + sampler = DDIMSampler(model, device=device) os.makedirs(opt.outdir, exist_ok=True) outpath = opt.outdir @@ -231,9 +257,82 @@ def main(opt): if opt.fixed_code: start_code = torch.randn([opt.n_samples, opt.C, opt.H // 
opt.f, opt.W // opt.f], device=device) - precision_scope = autocast if opt.precision == "autocast" else nullcontext + if opt.torchscript or opt.ipex: + transformer = model.cond_stage_model.model + unet = model.model.diffusion_model + decoder = model.first_stage_model.decoder + additional_context = torch.cpu.amp.autocast() if opt.bf16 else nullcontext() + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + + if opt.bf16 and not opt.torchscript and not opt.ipex: + raise ValueError('Bfloat16 is supported only for torchscript+ipex') + if opt.bf16 and unet.dtype != torch.bfloat16: + raise ValueError("Use configs/stable-diffusion/intel/ configs with bf16 enabled if " + + "you'd like to use bfloat16 with CPU.") + if unet.dtype == torch.float16 and device == torch.device("cpu"): + raise ValueError("Use configs/stable-diffusion/intel/ configs for your model if you'd like to run it on CPU.") + + if opt.ipex: + import intel_extension_for_pytorch as ipex + bf16_dtype = torch.bfloat16 if opt.bf16 else None + transformer = transformer.to(memory_format=torch.channels_last) + transformer = ipex.optimize(transformer, level="O1", inplace=True) + + unet = unet.to(memory_format=torch.channels_last) + unet = ipex.optimize(unet, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype) + + decoder = decoder.to(memory_format=torch.channels_last) + decoder = ipex.optimize(decoder, level="O1", auto_kernel_selection=True, inplace=True, dtype=bf16_dtype) + + if opt.torchscript: + with torch.no_grad(), additional_context: + # get UNET scripted + if unet.use_checkpoint: + raise ValueError("Gradient checkpoint won't work with tracing. 
" + + "Use configs/stable-diffusion/intel/ configs for your model or disable checkpoint in your config.") + + img_in = torch.ones(2, 4, 96, 96, dtype=torch.float32) + t_in = torch.ones(2, dtype=torch.int64) + context = torch.ones(2, 77, 1024, dtype=torch.float32) + scripted_unet = torch.jit.trace(unet, (img_in, t_in, context)) + scripted_unet = torch.jit.optimize_for_inference(scripted_unet) + print(type(scripted_unet)) + model.model.scripted_diffusion_model = scripted_unet + + # get Decoder for first stage model scripted + samples_ddim = torch.ones(1, 4, 96, 96, dtype=torch.float32) + scripted_decoder = torch.jit.trace(decoder, (samples_ddim)) + scripted_decoder = torch.jit.optimize_for_inference(scripted_decoder) + print(type(scripted_decoder)) + model.first_stage_model.decoder = scripted_decoder + + prompts = data[0] + print("Running a forward pass to initialize optimizations") + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + + with torch.no_grad(), additional_context: + for _ in range(3): + c = model.get_learned_conditioning(prompts) + samples_ddim, _ = sampler.sample(S=5, + conditioning=c, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + print("Running a forward pass for decoder") + for _ in range(3): + x_samples_ddim = model.decode_first_stage(samples_ddim) + + precision_scope = autocast if opt.precision=="autocast" or opt.bf16 else nullcontext with torch.no_grad(), \ - precision_scope("cuda"), \ + precision_scope(opt.device), \ model.ema_scope(): all_samples = list() for n in trange(opt.n_iter, desc="Sampling"):