mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2024-12-22 15:44:58 +00:00
Add FP32 fallback support on ldm/modules/diffusionmodules/openaimodel.py
This tries to execute interpolate with FP32 if it failed. Background is that on some environment such as Mx chip MacOS devices, we get error as follows: ``` " File "ldm/modules/diffusionmodules/openaimodel.py", line 115, in forward x = F.interpolate(x, scale_factor=2, mode="nearest") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "torch/nn/functional.py", line 3931, in interpolate return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half' ``` Therefore this commit adds the FP32 fallback execution to solve it.
This commit is contained in:
parent
cf1d67a6fd
commit
ca92a1406b
1 changed files with 7 additions and 1 deletions
|
@ -112,7 +112,13 @@ class Upsample(nn.Module):
|
|||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
||||
)
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
try:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
except RuntimeError as e:
|
||||
if "not implemented for" in str(e) and "Half" in str(e):
|
||||
x = F.interpolate(x.to(th.float32), scale_factor=2, mode="nearest").to(x.dtype)
|
||||
else:
|
||||
print(f"An unexpected RuntimeError occurred: {str(e)}")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
|
Loading…
Reference in a new issue