adding deepspeed for inference optimization

This commit is contained in:
Viraat 2022-11-29 02:34:50 -08:00
parent 47b6b607fd
commit 5820f54cb0
3 changed files with 26 additions and 0 deletions

View file

@ -13,4 +13,5 @@ transformers==4.19.2
webdataset==0.2.5 webdataset==0.2.5
open-clip-torch==2.7.0 open-clip-torch==2.7.0
gradio==3.11 gradio==3.11
deepspeed==0.7.5
-e . -e .

View file

@ -20,6 +20,8 @@ from scripts.txt2img import put_watermark
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
import deepspeed
def chunk(it, size): def chunk(it, size):
it = iter(it) it = iter(it)
@ -43,6 +45,16 @@ def load_model_from_config(config, ckpt, verbose=False):
model.cuda() model.cuda()
model.eval() model.eval()
ds_engine = deepspeed.init_inference(model,
mp_size=2,
dtype=torch.half,
checkpoint=None,
replace_method='auto',
replace_with_kernel_inject=True)
model = ds_engine.module
return model return model

View file

@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler
import deepspeed
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
@ -40,8 +41,20 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:") print("unexpected keys:")
print(u) print(u)
model.cuda() model.cuda()
model.eval() model.eval()
# Initialize the DeepSpeed-Inference engine
ds_engine = deepspeed.init_inference(model,
mp_size=2,
dtype=torch.half,
checkpoint=None,
replace_method='auto',
replace_with_kernel_inject=True)
model = ds_engine.module
return model return model