PyTorch: SpectralNorm cannot be applied as custom nn.Module's weight is None

I tried to use torch.nn.utils.spectral_norm on the SeparableConv2d layer.

I have used the _initialize_weights function to initialize weights for the layer but spectral_norm still raises errors as it cannot be applied when the parameter weight is None.

Here is the code

from torch.nn.utils import spectral_norm

def _initialize_weights(layers):
    for layer in layers:
        init.kaiming_normal_(layer.weight)
        layer.weight.data *= 0.1
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)


class SeparableConv2d(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True
    ):
        super(SeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            bias=bias,
            padding=padding,
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

        _initialize_weights([self.depthwise, self.pointwise])

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class Discriminator(nn.Module):
    def __init__(self, num_in_ch=3, num_feat=64):
        super(Discriminator, self).__init__()

        self.conv0 = SeparableConv2d(
            num_in_ch, num_feat, kernel_size=3, stride=1, padding=1
        )

        self.conv1 = spectral_norm(
            SeparableConv2d(num_feat, num_feat * 2, 3, 2, 1, bias=False)
        )
        self.conv2 = spectral_norm(
            SeparableConv2d(num_feat * 2, num_feat * 4, 3, 2, 1, bias=False)
        )

        # Center
        self.conv3 = spectral_norm(
            SeparableConv2d(num_feat * 4, num_feat * 8, 3, 2, 1, bias=False)
        )

        self.gating = spectral_norm(
            SeparableConv2d(num_feat * 8, num_feat * 4, 1, 1, 1, bias=False)
        )

        # attention Blocks
        self.attn_1 = AttentionNexusBlock(
            x_channels=num_feat * 4, g_channels=num_feat * 4
        )
        self.attn_2 = AttentionNexusBlock(
            x_channels=num_feat * 2, g_channels=num_feat * 4
        )
        self.attn_3 = AttentionNexusBlock(x_channels=num_feat, g_channels=num_feat * 4)

        # Cat
        self.cat_1 = ConcatenationNexusBlock(dim_in=num_feat * 8, dim_out=num_feat * 4)
        self.cat_2 = ConcatenationNexusBlock(dim_in=num_feat * 4, dim_out=num_feat * 2)
        self.cat_3 = ConcatenationNexusBlock(dim_in=num_feat * 2, dim_out=num_feat)

        # upsample
        self.conv4 = spectral_norm(
            SeparableConv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)
        )
        self.conv5 = spectral_norm(
            SeparableConv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)
        )
        self.conv6 = spectral_norm(
            SeparableConv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)
        )

        # extra
        self.conv7 = spectral_norm(
            SeparableConv2d(num_feat, num_feat, 3, 1, 1, bias=False)
        )
        self.conv8 = spectral_norm(
            SeparableConv2d(num_feat, num_feat, 3, 1, 1, bias=False)
        )
        self.conv9 = SeparableConv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        gated = F.leaky_relu(self.gating(x3), negative_slope=0.2, inplace=True)

        # Attention
        attn1 = self.attn_1(x2, gated)
        attn2 = self.attn_2(x1, gated)
        attn3 = self.attn_3(x0, gated)

        # upsample
        x3 = self.cat_1(attn1, x3)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
        x4 = self.cat_2(attn2, x4)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
        x5 = self.cat_3(attn3, x5)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        # extra
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out

And, here is the error

KeyError                                  Traceback (most recent call last)
/home/khoa/dev/GAN/summarize_archs.ipynb Cell 6' in <cell line: 1>()
----> 1 summary(ND(), (1, 3, 256, 256))

File ~/dev/GAN/networks/gan/models.py:96, in Discriminator.__init__(self, num_in_ch, num_feat)
     90 super(Discriminator, self).__init__()
     92 self.conv0 = SeparableConv2d(
     93     num_in_ch, num_feat, kernel_size=3, stride=1, padding=1
     94 )
---> 96 self.conv1 = spectral_norm(
     97     SeparableConv2d(num_feat, num_feat * 2, 3, 2, 1, bias=False)
     98 )
     99 self.conv2 = spectral_norm(
    100     SeparableConv2d(num_feat * 2, num_feat * 4, 3, 2, 1, bias=False)
    101 )
    103 # Center

File ~/anaconda3/envs/nexus/lib/python3.9/site-packages/torch/nn/utils/spectral_norm.py:280, in spectral_norm(module, name, n_power_iterations, eps, dim)
    278     else:
    279         dim = 0
--> 280 SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    281 return module

File ~/anaconda3/envs/nexus/lib/python3.9/site-packages/torch/nn/utils/spectral_norm.py:122, in SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    118         raise RuntimeError("Cannot register two spectral_norm hooks on "
    119                            "the same parameter {}".format(name))
    121 fn = SpectralNorm(name, n_power_iterations, dim, eps)
--> 122 weight = module._parameters[name]
    123 if weight is None:
    124     raise ValueError(f'`SpectralNorm` cannot be applied as parameter `{name}` is None')

KeyError: 'weight'

I’m totally getting lost, can you please tell me how to fix it!!

Shouldn’t this be layer.weight = init.kaiming_normal_(layer.weight)? As the init methods return a Tensor it makes sense to assign it to that layer. (The same applies for layer.bias too)

Also, you shouldn’t be modifying .data on a Tensor, do something like layer.weight = 0.1*layer.weight.

And, I’d recommend setting inplace=False rather than inplace=True, the speed-up is minimal and usually leads to more errors!

1 Like