StyleGAN2 Demodulation grouped convolution incorrect dimensions

Hello,

The following code:

class mod_demod(nn.Module):
    def __init__(self, latent_dim, out_channels):
        super().__init__()
    
        self.map_layer = EqualLRLinear(latent_dim, out_channels)
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

        self.out_channels = out_channels
    def forward(self, w, conv_weight, batch):
        # conv_weight is the weights of the current convolutional layer
        s = self.map_layer(w)
        s = s.view(w.size(0), -1, 1, 1, 1)
        # Add bias and 1 (as per StyleGAN2)
        s = s + 1 + self.bias.view(1, -1, 1, 1, 1)
        # Modulation
        weight = conv_weight * s

        # Demodulation
        demod = torch.rsqrt(weight.pow(2).sum([2,3,4])+ 1e-8)
        weight = weight * demod.unsqueeze(2).unsqueeze(3).unsqueeze(4)

        # Return the weight's of the convolution
        return nn.Parameter(weight)
        

class Conv2d_mod(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, latent_dim, padding=1):
        super().__init__()

        # The weights for our conv
        self.weights = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.normal_(self.weights)

        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))  # b2 in diagram
        
        self.modulator = mod_demod(latent_dim, out_channels)

        self.eq_lr_scale = sqrt(2 / (in_channels * kernel_size ** 2))

        self.padding = padding
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.latent_dim = latent_dim

    def forward(self, x, style_w):
        batch, channels, H, W = x.shape

        weights = self.weights * self.eq_lr_scale
        # Need to create the modulated and demodulated weights
        weights = self.modulator(style_w, weights, batch)

        print(f'Channels in x: {channels} | self.in_channels: {self.in_channels} | self.out_channels: {self.out_channels}')

        print(weights.shape)
        weights = weights.view(batch * self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        #weights = weights.view(batch * self.out_channels, channels, self.kernel_size, self.kernel_size)

        x = x.view(1, batch*channels, H, W)
        print(f'x.shape: {x.shape} | weights.shape: {weights.shape}')
        out = F.conv2d(x, weights, groups=batch, padding=self.padding)
        out = out.view(batch, self.out_channels, out.shape[-2], out.shape[-1])

        out += self.bias
            
        return out

Works fine for the input:
conv2d_mod = Conv2d_mod(256, 256, 3, 256)
x = torch.randn(4, 256, 64, 64)
w = torch.randn(4, 256)

But when the input is:

conv2d_mod = Conv2d_mod(256, 256//2, 3, 256)
x = torch.randn(4, 128, 128, 128)
w = torch.randn(4, 256)

It throws the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[183], line 4
      2 x = torch.randn(4, 128, 128, 128)
      3 w = torch.randn(4, 256)
----> 4 conv2d_mod(x, w).shape

File ~/anaconda3/envs/gpu_use/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/gpu_use/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

Cell In[181], line 36, in Conv2d_mod.forward(self, x, style_w)
     34 x = x.view(1, batch*channels, H, W)
     35 print(f'x.shape: {x.shape} | weights.shape: {weights.shape}')
---> 36 out = F.conv2d(x, weights, groups=batch, padding=self.padding)
     37 out = out.view(batch, self.out_channels, out.shape[-2], out.shape[-1])
     39 out += self.bias

RuntimeError: Given groups=4, weight of size [512, 256, 3, 3], expected input[1, 512, 128, 128] to have 1024 channels, but got 512 channels instead

I am stumped as to what is causing this error, it appears to be due to the in and out channels being different but this is typical for a StyleGAN so it needs to stay this way.

Any ideas on how to fix?

I’m not familiar with your code but note that in_channels=256 in Conv2d_mod while the actual input uses only 128 channels. Using:

conv2d_mod = Conv2d_mod(256, 256//2, 3, 256)
x = torch.randn(4, 256, 128, 128)
w = torch.randn(4, 256)
out = conv2d_mod(x, w)

fixes the shape mismatch.

1 Like