Hello,
The following code:
class mod_demod(nn.Module):
def __init__(self, latent_dim, out_channels):
super().__init__()
self.map_layer = EqualLRLinear(latent_dim, out_channels)
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
self.out_channels = out_channels
def forward(self, w, conv_weight, batch):
# conv_weight is the weights of the current convolutional layer
s = self.map_layer(w)
s = s.view(w.size(0), -1, 1, 1, 1)
# Add bias and 1 (as per StyleGAN2)
s = s + 1 + self.bias.view(1, -1, 1, 1, 1)
# Modulation
weight = conv_weight * s
# Demodulation
demod = torch.rsqrt(weight.pow(2).sum([2,3,4])+ 1e-8)
weight = weight * demod.unsqueeze(2).unsqueeze(3).unsqueeze(4)
# Return the weight's of the convolution
return nn.Parameter(weight)
class Conv2d_mod(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, latent_dim, padding=1):
super().__init__()
# The weights for our conv
self.weights = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
nn.init.normal_(self.weights)
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) # b2 in diagram
self.modulator = mod_demod(latent_dim, out_channels)
self.eq_lr_scale = sqrt(2 / (in_channels * kernel_size ** 2))
self.padding = padding
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.latent_dim = latent_dim
def forward(self, x, style_w):
batch, channels, H, W = x.shape
weights = self.weights * self.eq_lr_scale
# Need to create the modulated and demodulated weights
weights = self.modulator(style_w, weights, batch)
print(f'Channels in x: {channels} | self.in_channels: {self.in_channels} | self.out_channels: {self.out_channels}')
print(weights.shape)
weights = weights.view(batch * self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
#weights = weights.view(batch * self.out_channels, channels, self.kernel_size, self.kernel_size)
x = x.view(1, batch*channels, H, W)
print(f'x.shape: {x.shape} | weights.shape: {weights.shape}')
out = F.conv2d(x, weights, groups=batch, padding=self.padding)
out = out.view(batch, self.out_channels, out.shape[-2], out.shape[-1])
out += self.bias
return out
Works fine for the input:
conv2d_mod = Conv2d_mod(256, 256, 3, 256)
x = torch.randn(4, 256, 64, 64)
w = torch.randn(4, 256)
But when the input is:
conv2d_mod = Conv2d_mod(256, 256//2, 3, 256)
x = torch.randn(4, 128, 128, 128)
w = torch.randn(4, 256)
It throws the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[183], line 4
2 x = torch.randn(4, 128, 128, 128)
3 w = torch.randn(4, 256)
----> 4 conv2d_mod(x, w).shape
File ~/anaconda3/envs/gpu_use/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/gpu_use/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
Cell In[181], line 36, in Conv2d_mod.forward(self, x, style_w)
34 x = x.view(1, batch*channels, H, W)
35 print(f'x.shape: {x.shape} | weights.shape: {weights.shape}')
---> 36 out = F.conv2d(x, weights, groups=batch, padding=self.padding)
37 out = out.view(batch, self.out_channels, out.shape[-2], out.shape[-1])
39 out += self.bias
RuntimeError: Given groups=4, weight of size [512, 256, 3, 3], expected input[1, 512, 128, 128] to have 1024 channels, but got 512 channels instead
I am stumped as to what is causing this error, it appears to be due to the in and out channels being different but this is typical for a StyleGAN so it needs to stay this way.
Any ideas on how to fix?