Fix wrong dim setting in SpatialTransformer

The data dim:
in_ch -> inner_dim -> in_ch

So the proj_out should be inner_dim -> in_ch
This commit is contained in:
Kohaku-Blueleaf 2022-12-16 11:18:46 +08:00
parent d55bcd4d31
commit 5c0b4cb97d

View file

@ -315,7 +315,7 @@ class SpatialTransformer(nn.Module):
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):