diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index 509cd87..d3b978f 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -315,7 +315,7 @@ class SpatialTransformer(nn.Module):
                                                   stride=1,
                                                   padding=0))
         else:
-            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
         self.use_linear = use_linear
 
     def forward(self, x, context=None):
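
For context, a minimal sketch (not the repository's code) of why the argument order matters: proj_out is the inverse of proj_in. Tokens pass through the transformer blocks at width inner_dim = n_heads * d_head, so the output projection has to map inner_dim back to in_channels before the residual add; the swapped arguments only happen to work when the two widths coincide. The concrete sizes below (320 channels, 8 heads of 64) are assumed purely for illustration.

    import torch
    import torch.nn as nn

    # Illustrative sizes, assumed for this sketch; in SpatialTransformer,
    # inner_dim = n_heads * d_head and generally differs from in_channels.
    in_channels, n_heads, d_head = 320, 8, 64
    inner_dim = n_heads * d_head                 # 512, != in_channels

    b, hw = 2, 256                               # batch, flattened spatial tokens
    x = torch.randn(b, hw, in_channels)          # tokens in image-feature space

    proj_in = nn.Linear(in_channels, inner_dim)  # into the transformer's width
    proj_out = nn.Linear(inner_dim, in_channels) # fixed orientation: back out

    h = proj_in(x)                               # (b, hw, inner_dim); attention blocks run at this width
    out = proj_out(h)                            # (b, hw, in_channels)
    assert out.shape == x.shape                  # residual x + out now lines up

    # With the pre-fix layer, nn.Linear(in_channels, inner_dim), the proj_out
    # call above would raise a shape mismatch, since h has last dimension
    # inner_dim rather than in_channels.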