diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index daf35da..29b6fc6 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -170,8 +170,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): if not repeat_only: half = dim // 2 freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half + ) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: