Fix SpatialTransformer proj_out in attention.py

This commit is contained in:
JakoError 2023-03-15 16:56:18 +08:00
parent fc1488421a
commit 388d673583

View file

@@ -315,7 +315,7 @@ class SpatialTransformer(nn.Module):
                                               stride=1,
                                               padding=0))
         else:
-            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
         self.use_linear = use_linear

     def forward(self, x, context=None):