diff --git a/requirements.txt b/requirements.txt index 2404caa..8bd6f44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ transformers==4.19.2 webdataset==0.2.5 open-clip-torch==2.7.0 gradio==3.11 +deepspeed==0.7.5 -e . diff --git a/scripts/img2img.py b/scripts/img2img.py index 9085ba9..6ba1f78 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -20,6 +20,8 @@ from scripts.txt2img import put_watermark from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler +import deepspeed + def chunk(it, size): it = iter(it) @@ -43,6 +45,16 @@ def load_model_from_config(config, ckpt, verbose=False): model.cuda() model.eval() + + ds_engine = deepspeed.init_inference(model, + mp_size=2, + dtype=torch.half, + checkpoint=None, + replace_method='auto', + replace_with_kernel_inject=True) + + model = ds_engine.module + return model diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 1ed42a3..6258108 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler +import deepspeed torch.set_grad_enabled(False) @@ -39,9 +40,21 @@ def load_model_from_config(config, ckpt, verbose=False): if len(u) > 0 and verbose: print("unexpected keys:") print(u) + model.cuda() model.eval() + + # Initialize the DeepSpeed-Inference engine + ds_engine = deepspeed.init_inference(model, + mp_size=2, + dtype=torch.half, + checkpoint=None, + replace_method='auto', + replace_with_kernel_inject=True) + + model = ds_engine.module + return model