diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py index e4d0d0a..fd6e155 100644 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -1,4 +1,3 @@ -"""SAMPLING ONLY.""" import torch from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver @@ -8,46 +7,87 @@ MODEL_TYPES = { "v": "v" } - class DPMSolverSampler(object): - def __init__(self, model, device=torch.device("cuda"), **kwargs): + def __init__(self, model: torch.nn.Module, device: torch.device = torch.device("cuda"), **kwargs) -> None: + """ + Initialize the DPMSolverSampler. + + Args: + model (torch.nn.Module): The model to use for sampling. + device (torch.device, optional): The device to use. Defaults to torch.device("cuda"). + **kwargs: Additional keyword arguments. + """ super().__init__() self.model = model self.device = device to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: + def register_buffer(self, name: str, attr: torch.Tensor) -> None: + """ + Register a buffer in the module. + + Args: + name (str): The name of the buffer. + attr (torch.Tensor): The tensor to register as a buffer. + """ + if isinstance(attr, torch.Tensor): if attr.device != self.device: attr = attr.to(self.device) setattr(self, name, attr) @torch.no_grad() def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): + S: int, + batch_size: int, + shape: tuple, + conditioning: dict = None, + callback: callable = None, + normals_sequence: list = None, + img_callback: callable = None, + quantize_x0: bool = False, + eta: float = 0., + mask: torch.Tensor = None, + x0: torch.Tensor = None, + temperature: float = 1., + noise_dropout: float = 0., + score_corrector: callable = None, + corrector_kwargs: dict = None, + verbose: bool = True, + x_T: torch.Tensor = None, + log_every_t: int = 100, + unconditional_guidance_scale: float = 1., + unconditional_conditioning: torch.Tensor = None, + **kwargs) -> tuple: + """ + Perform sampling using the DPM Solver. + + Args: + S (int): Number of steps. + batch_size (int): Batch size for sampling. + shape (tuple): Shape of the samples (C, H, W). + conditioning (dict, optional): Conditioning information. Defaults to None. + callback (callable, optional): Callback function. Defaults to None. + normals_sequence (list, optional): Sequence of normals. Defaults to None. + img_callback (callable, optional): Image callback function. Defaults to None. + quantize_x0 (bool, optional): Flag for quantizing x0. Defaults to False. + eta (float, optional): Eta parameter. Defaults to 0.. + mask (torch.Tensor, optional): Mask tensor. Defaults to None. + x0 (torch.Tensor, optional): Initial x0 tensor. Defaults to None. + temperature (float, optional): Temperature parameter. Defaults to 1.. + noise_dropout (float, optional): Noise dropout parameter. Defaults to 0.. + score_corrector (callable, optional): Score corrector. Defaults to None. + corrector_kwargs (dict, optional): Keyword arguments for the score corrector. Defaults to None. + verbose (bool, optional): Verbose flag. Defaults to True. + x_T (torch.Tensor, optional): Initial x_T tensor. Defaults to None. + log_every_t (int, optional): Log interval. Defaults to 100. + unconditional_guidance_scale (float, optional): Guidance scale for unconditional sampling. Defaults to 1.. + unconditional_conditioning (torch.Tensor, optional): Conditioning tensor for unconditional sampling. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + tuple: Sampled tensor and additional information. + """ if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] @@ -94,3 +134,4 @@ class DPMSolverSampler(object): lower_order_final=True) return x.to(device), None +