Hi guys,
I was wondering whether this forward pass correctly aligns the dimensions of the residual connection:
def forward(self, x):
    """Decode ``x`` through attention, four upsampling stages and one skip connection.

    Pipeline::

        seq_attention -> (deconv -> norm -> activation) x 4
        -> add resized skip taken after stage 2
        -> final_layer -> resize to ``self.final_shape`` -> tanh

    Args:
        x: Input tensor. Assumed to be laid out as (N, C, *spatial), given
            the use of ``x.shape[2:]`` with ``F.interpolate`` below
            (TODO confirm with the caller).

    Returns:
        Tuple ``(output, self_attn)``. ``output`` has spatial size
        ``self.final_shape``. ``self_attn`` is the second value returned by
        ``self.seq_attention``.
    """
    x, self_attn = self.seq_attention(x)

    x = self.activation(self.norm1(self.deconv1(x)))
    x = self.activation(self.norm2(self.deconv2(x)))

    # Skip connection taken after stage 2.
    residual = x

    x = self.activation(self.norm3(self.deconv3(x)))
    x = self.activation(self.norm4(self.deconv4(x)))

    # F.interpolate resizes only the spatial dims (shape[2:]). Batch and
    # channel dims are left untouched, so this add also requires deconv4's
    # out-channels to equal deconv2's out-channels. If they differ, project
    # ``residual`` with a 1x1 conv before adding. 'nearest' is cheap but
    # blocky; a linear-family mode gives a smoother upsampled skip.
    x = x + F.interpolate(residual, size=x.shape[2:], mode="nearest")

    x = self.final_layer(x)

    # mode='linear' accepts only 3-D input (N, C, L). Pick the linear-family
    # mode that matches the tensor rank so 2-D (N, C, H, W) and 3-D
    # (N, C, D, H, W) decoders also work. Other ranks fall back to 'linear'
    # and F.interpolate raises its usual error, as before.
    # ``self.final_shape`` must give one size per spatial dim.
    mode = {4: "bilinear", 5: "trilinear"}.get(x.dim(), "linear")
    x = F.interpolate(x, size=self.final_shape, mode=mode, align_corners=False)
    x = self.tanh(x)
    return x, self_attn
I would greatly appreciate any comments, including the pros and cons of this approach.
Thank you!