Added type hints and improved docstrings for DPMSolverSampler class.

This commit is contained in:
Mandlin Sarah 2024-09-04 11:22:02 -07:00
parent cf1d67a6fd
commit 80ca3cd485

View file

@ -1,4 +1,3 @@
"""SAMPLING ONLY."""
import torch import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
@ -8,46 +7,87 @@ MODEL_TYPES = {
"v": "v" "v": "v"
} }
class DPMSolverSampler(object): class DPMSolverSampler(object):
def __init__(self, model, device=torch.device("cuda"), **kwargs): def __init__(self, model: torch.nn.Module, device: torch.device = torch.device("cuda"), **kwargs) -> None:
"""
Initialize the DPMSolverSampler.
Args:
model (torch.nn.Module): The model to use for sampling.
device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
**kwargs: Additional keyword arguments.
"""
super().__init__() super().__init__()
self.model = model self.model = model
self.device = device self.device = device
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr): def register_buffer(self, name: str, attr: torch.Tensor) -> None:
if type(attr) == torch.Tensor: """
Store ``attr`` on this sampler as the attribute ``name``; tensors are first moved to ``self.device``.
Args:
name (str): The name of the buffer.
attr (torch.Tensor): The tensor to register as a buffer.
"""
if isinstance(attr, torch.Tensor):
if attr.device != self.device: if attr.device != self.device:
attr = attr.to(self.device) attr = attr.to(self.device)
setattr(self, name, attr) setattr(self, name, attr)
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(self,
S, S: int,
batch_size, batch_size: int,
shape, shape: tuple,
conditioning=None, conditioning: dict = None,
callback=None, callback: callable = None,
normals_sequence=None, normals_sequence: list = None,
img_callback=None, img_callback: callable = None,
quantize_x0=False, quantize_x0: bool = False,
eta=0., eta: float = 0.,
mask=None, mask: torch.Tensor = None,
x0=None, x0: torch.Tensor = None,
temperature=1., temperature: float = 1.,
noise_dropout=0., noise_dropout: float = 0.,
score_corrector=None, score_corrector: callable = None,
corrector_kwargs=None, corrector_kwargs: dict = None,
verbose=True, verbose: bool = True,
x_T=None, x_T: torch.Tensor = None,
log_every_t=100, log_every_t: int = 100,
unconditional_guidance_scale=1., unconditional_guidance_scale: float = 1.,
unconditional_conditioning=None, unconditional_conditioning: torch.Tensor = None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... **kwargs) -> tuple:
**kwargs """
): Perform sampling using the DPM Solver.
Args:
S (int): Number of steps.
batch_size (int): Batch size for sampling.
shape (tuple): Shape of the samples (C, H, W).
conditioning (dict, optional): Conditioning information. Defaults to None.
callback (callable, optional): Callback function. Defaults to None.
normals_sequence (list, optional): Sequence of normals. Defaults to None.
img_callback (callable, optional): Image callback function. Defaults to None.
quantize_x0 (bool, optional): Flag for quantizing x0. Defaults to False.
eta (float, optional): Eta parameter. Defaults to 0.0.
mask (torch.Tensor, optional): Mask tensor. Defaults to None.
x0 (torch.Tensor, optional): Initial x0 tensor. Defaults to None.
temperature (float, optional): Temperature parameter. Defaults to 1.0.
noise_dropout (float, optional): Noise dropout parameter. Defaults to 0.0.
score_corrector (callable, optional): Score corrector. Defaults to None.
corrector_kwargs (dict, optional): Keyword arguments for the score corrector. Defaults to None.
verbose (bool, optional): Verbose flag. Defaults to True.
x_T (torch.Tensor, optional): Initial x_T tensor. Defaults to None.
log_every_t (int, optional): Log interval. Defaults to 100.
unconditional_guidance_scale (float, optional): Guidance scale for unconditional sampling. Defaults to 1.0.
unconditional_conditioning (torch.Tensor, optional): Conditioning tensor for unconditional sampling. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
tuple: The sampled tensor and ``None`` (the second element is always ``None``).
"""
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]] ctmp = conditioning[list(conditioning.keys())[0]]
@ -94,3 +134,4 @@ class DPMSolverSampler(object):
lower_order_final=True) lower_order_final=True)
return x.to(device), None return x.to(device), None