Fix wrong dim setting in SpatialTransformer

The data dim:
in_ch -> inner_dim -> in_ch

So the proj_out should be inner_dim -> in_ch
This commit is contained in:
Kohaku-Blueleaf 2022-12-16 11:18:46 +08:00
parent d55bcd4d31
commit 5c0b4cb97d

View file

@ -315,7 +315,7 @@ class SpatialTransformer(nn.Module):
stride=1, stride=1,
padding=0)) padding=0))
else: else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear self.use_linear = use_linear
def forward(self, x, context=None): def forward(self, x, context=None):