Add FP32 fallback support on ldm/modules/diffusionmodules/openaimodel.py

This tries to execute interpolate with FP32 if it failed.

Background is that
on some environment such as Mx chip MacOS devices, we get error as follows:

```
"      File "ldm/modules/diffusionmodules/openaimodel.py", line 115, in forward
        x = F.interpolate(x, scale_factor=2, mode="nearest")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "torch/nn/functional.py", line 3931, in interpolate
        return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half'
```

Therefore this commit adds the FP32 fallback execution to solve it.
This commit is contained in:
hidenorly 2023-11-22 00:21:19 +09:00
parent cf1d67a6fd
commit ca92a1406b

View file

@ -112,7 +112,13 @@ class Upsample(nn.Module):
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
) )
else: else:
try:
x = F.interpolate(x, scale_factor=2, mode="nearest") x = F.interpolate(x, scale_factor=2, mode="nearest")
except RuntimeError as e:
if "not implemented for" in str(e) and "Half" in str(e):
x = F.interpolate(x.to(th.float32), scale_factor=2, mode="nearest").to(x.dtype)
else:
print(f"An unexpected RuntimeError occurred: {str(e)}")
if self.use_conv: if self.use_conv:
x = self.conv(x) x = self.conv(x)
return x return x