Fixes an issue when training embedding

In certain cases (most notably on Colab using Automatic1111's WebUI, but not limited to that example) trying to train an embedding would cause https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5821#issue-1501645853 this error.
This commit is contained in:
Gazzoo-byte 2022-12-23 21:39:40 +00:00 committed by GitHub
parent d55bcd4d31
commit 48301e389a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -900,6 +900,7 @@ class LatentDiffusion(DDPM):
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
self.logvar = self.logvar.to(self.device)
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar