diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 6090212..7bd6bb6 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -900,6 +900,7 @@ class LatentDiffusion(DDPM): loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + self.logvar = self.logvar.to(self.device) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar