Spectral Norm in eval mode

I’m currently implementing SAGAN in PyTorch, which uses the new nn.utils.spectral_norm (along with batch norm) for normalization. Samples look fine when generated during training, but when I load a snapshot and set the network to eval mode, I get complete garbage as output. If I don’t set eval mode, the first sample after loading is normal, but subsequent samples are distorted, even if I set requires_grad to False on all the parameters. Any ideas how I can fix this?

I’m curious about this phenomenon.

I don’t have a concrete idea, but I suspect BatchNorm plays a role, since eval mode switches to the population (running) statistics.

I think there are several parts:

  • I haven’t implemented SAGAN yet, but SNGAN doesn’t use eval mode at all when training the generator. I think that is mostly because of batch norm, which behaves quite differently between training and eval.
  • I would expect that for a fully trained network, the additional power-iteration updates of spectral_norm don’t change much. You could check what the largest singular value of the spectrally normalized weight is (see the sketch after this list). Also, doesn’t spectral normalization affect only the discriminator (at least in SNGAN)?
  • I’m surprised that the behaviour changes in training mode: while batch norm’s running statistics are updated there, they are not used (the per-batch statistics are). The weight/bias parameters, on the other hand, should not change without optimizer steps.
  • When you are just sampling, you don’t call the optimizer’s step, do you?
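
For the singular value check, something like this should work (just a sketch; conv stands in for any of your spectrally normalized layers):

import torch
import torch.nn as nn

# Illustrative layer; substitute one of the wrapped convs from your model.
conv = nn.utils.spectral_norm(nn.Conv2d(64, 64, 3))

# spectral_norm keeps the unnormalized weight as `weight_orig`; flatten it
# to 2D the same way spectral_norm does and take the top singular value.
w = conv.weight_orig.detach()
sigma = torch.svd(w.view(w.size(0), -1))[1][0]
print('largest singular value of weight_orig:', sigma.item())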

You could run a few experiments to narrow down the cause (a sketch for toggling only the batch norm layers follows this list):

  • If you just set the batch norms to eval, does it output garbage, too?
  • Does generating samples involve spectral normalization?
  • If you set everything to eval except the batch norms, does it work? (Then you don’t need to worry about spectral norm.)
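
Toggling only the batch norm layers could look like this (a sketch, assuming your generator is called netG):

import torch.nn as nn

# Everything in train mode except the batch norms:
netG.train()
for m in netG.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

# Or the inverse: everything in eval mode except the batch norms:
netG.eval()
for m in netG.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.train()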

Best regards

Thomas


With BatchNorm in eval mode and everything else in training mode, the result is even weirder: the first run still works, but instead of being distorted, subsequent samples are complete garbage.

Setting all the spectral norm layers to eval, however, produced garbage from the first sample.

I have spectral norm in both the discriminator and the generator. I don’t even have an optimizer; this script just loads the separately trained weights and generates samples.

By the way, here are all relevant parts of my model:

ysize = 45
zsize = 100
isize = 128

batchsize = 16

gen_up_blocks = 4
gen_inp_planes = 32
gen_resblocks = 3
gen_filters = 256
gen_self_attn = [0, 1, 0, 0, 0]

class Generator(nn.Module):
    def __init__(self, isize, ysize, zsize, gen_up_blocks, gen_inp_planes,
                 gen_resblocks, gen_filters, gen_self_attn):
        super(Generator, self).__init__()
        sidedim = isize // (2 ** gen_up_blocks)  # each up block doubles the side length
        self.sidedim = sidedim
        self.latent = nn.Linear(zsize + ysize, (sidedim ** 2) * gen_inp_planes)
        main = nn.Sequential()

        main.add_module('pad_init', nn.ReflectionPad2d(1))
        main.add_module('conv_init', nn.utils.spectral_norm(nn.Conv2d(gen_inp_planes, gen_filters, 3)))

        if gen_self_attn[0]:
            main.add_module('sa_{}'.format(0), Self_Attn(gen_filters))


        for i in range(gen_up_blocks):
            mod = GenUpBlock(gen_filters, gen_filters // 2, gen_resblocks)

            main.add_module('up_{}'.format(i), mod)

            gen_filters //= 2

            if gen_self_attn[i + 1]:
                main.add_module('sa_{}'.format(i + 1), Self_Attn(gen_filters))

        main.add_module('padding', nn.ReflectionPad2d(3))
        main.add_module('conv_final', nn.utils.spectral_norm(nn.Conv2d(gen_filters, 3, 7)))
        main.add_module('sigmoid', nn.Sigmoid())
        self.gen_inp_planes = gen_inp_planes
        self.main = main

    def forward(self, z, y):
        z = self.latent(torch.cat((z, y), 1))
        z = z.view(-1, self.gen_inp_planes, self.sidedim, self.sidedim)
        z = self.main(z)

        return z

class ResBlock(nn.Module):
    def __init__(self, channels, bn):
        super(ResBlock, self).__init__()

        self.bn = bn

        self.pad1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(channels, channels, 3))
        if bn:
            self.bn1 = nn.BatchNorm2d(channels)
        self.pad2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(channels, channels, 3))
        if bn:
            self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.pad1(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)
        out = nn.functional.leaky_relu(out, negative_slope=0.2)
        out = self.pad2(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)
        out += residual
        out = nn.functional.leaky_relu(out, negative_slope=0.2)
        return out

class GenUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, resblocks):
        super(GenUpBlock, self).__init__()

        self.res = nn.Sequential()
        for i in range(resblocks):
            self.res.add_module('res_{}'.format(i), ResBlock(in_channels, True))
        self.pad = nn.ReflectionPad2d(1)
        self.conv = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels * 4, 3))
        self.shuf = nn.PixelShuffle(2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        skip = x
        x = self.res(x)
        x += skip
        x = self.pad(x)
        x = self.conv(x)
        x = self.shuf(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# (Modified) From https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py
class Self_Attn(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()

        self.kernel_size = 1

        self.query_conv = nn.utils.spectral_norm(nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // 8, kernel_size=self.kernel_size))
        self.key_conv = nn.utils.spectral_norm(nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // 8, kernel_size=self.kernel_size))
        self.value_conv = nn.utils.spectral_norm(nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=self.kernel_size))
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps (B x C x W x H)
            returns :
                out : self-attention value + input feature (B x C x W x H)
        """
        m_batchsize, C, width, height = x.size()

        proj_query = self.query_conv(x).view(
            m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(
            m_batchsize, -1, width * height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # normalized over the last dim
        proj_value = self.value_conv(x).view(
            m_batchsize, -1, width * height)  # B x C x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        out = self.gamma * out + x
        return out
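
For reference, a quick shape smoke test along these lines (my sketch, using the constants at the top) confirms that the pieces fit together:

import torch

netG = Generator(isize, ysize, zsize, gen_up_blocks, gen_inp_planes,
                 gen_resblocks, gen_filters, gen_self_attn)

z = torch.randn(batchsize, zsize)
y = torch.randn(batchsize, ysize)  # stand-in for the real condition vector

with torch.no_grad():
    out = netG(z, y)
print(out.shape)  # torch.Size([16, 3, 128, 128])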

Hmm. One of these days I’ll have to plug it into my SNGAN workbook to reproduce the failure.

Is this with master or an older version? I think we improved spectral norm at least before 0.4.1, but possibly also after.
On master, I think we expect spectral norm to give the same weight in eval mode as the last time in non-eval mode. Would you be able to check that (by saving the weight)? A sketch of such a check is below.
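
Something along these lines (names assumed: gen is a plain, non-DataParallel generator, conv_init an example layer):

import torch

gen.train()
with torch.no_grad():
    gen(z, y)  # last non-eval forward
w_train = gen.main.conv_init.weight.detach().clone()

gen.eval()
with torch.no_grad():
    gen(z, y)
w_eval = gen.main.conv_init.weight.detach().clone()

# On master this should print True if eval reuses the last computed weight.
print(torch.equal(w_train, w_eval))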

Best regards

Thomas


I’m using torch==0.4.1 from PyPI.

As expected, in eval mode none of the weights change (but the output is garbage). In train mode with no_grad() and requires_grad_(False), these are the weights that change (a sketch of how such a diff can be produced follows the list):

module.main.conv_init.weight
module.main.conv_init.weight_u
module.main.up_0.res.res_0.conv1.weight
module.main.up_0.res.res_0.conv1.weight_u
module.main.up_0.res.res_0.bn1.running_mean
module.main.up_0.res.res_0.bn1.running_var
module.main.up_0.res.res_0.bn1.num_batches_tracked
module.main.up_0.res.res_0.conv2.weight
module.main.up_0.res.res_0.conv2.weight_u
module.main.up_0.res.res_0.bn2.running_mean
module.main.up_0.res.res_0.bn2.running_var
module.main.up_0.res.res_0.bn2.num_batches_tracked
module.main.up_0.res.res_1.conv1.weight
module.main.up_0.res.res_1.conv1.weight_u
module.main.up_0.res.res_1.bn1.running_mean
module.main.up_0.res.res_1.bn1.running_var
module.main.up_0.res.res_1.bn1.num_batches_tracked
module.main.up_0.res.res_1.conv2.weight
module.main.up_0.res.res_1.conv2.weight_u
module.main.up_0.res.res_1.bn2.running_mean
module.main.up_0.res.res_1.bn2.running_var
module.main.up_0.res.res_1.bn2.num_batches_tracked
module.main.up_0.res.res_2.conv1.weight
module.main.up_0.res.res_2.conv1.weight_u
module.main.up_0.res.res_2.bn1.running_mean
module.main.up_0.res.res_2.bn1.running_var
module.main.up_0.res.res_2.bn1.num_batches_tracked
module.main.up_0.res.res_2.conv2.weight
module.main.up_0.res.res_2.conv2.weight_u
module.main.up_0.res.res_2.bn2.running_mean
module.main.up_0.res.res_2.bn2.running_var
module.main.up_0.res.res_2.bn2.num_batches_tracked
module.main.up_0.conv.weight
module.main.up_0.conv.weight_u
module.main.up_0.bn.running_mean
module.main.up_0.bn.running_var
module.main.up_0.bn.num_batches_tracked
module.main.sa_1.query_conv.weight
module.main.sa_1.query_conv.weight_u
module.main.sa_1.key_conv.weight
module.main.sa_1.key_conv.weight_u
module.main.sa_1.value_conv.weight
module.main.sa_1.value_conv.weight_u
module.main.up_1.res.res_0.conv1.weight
module.main.up_1.res.res_0.conv1.weight_u
module.main.up_1.res.res_0.bn1.running_mean
module.main.up_1.res.res_0.bn1.running_var
module.main.up_1.res.res_0.bn1.num_batches_tracked
module.main.up_1.res.res_0.conv2.weight
module.main.up_1.res.res_0.conv2.weight_u
module.main.up_1.res.res_0.bn2.running_mean
module.main.up_1.res.res_0.bn2.running_var
module.main.up_1.res.res_0.bn2.num_batches_tracked
module.main.up_1.res.res_1.conv1.weight
module.main.up_1.res.res_1.conv1.weight_u
module.main.up_1.res.res_1.bn1.running_mean
module.main.up_1.res.res_1.bn1.running_var
module.main.up_1.res.res_1.bn1.num_batches_tracked
module.main.up_1.res.res_1.conv2.weight
module.main.up_1.res.res_1.conv2.weight_u
module.main.up_1.res.res_1.bn2.running_mean
module.main.up_1.res.res_1.bn2.running_var
module.main.up_1.res.res_1.bn2.num_batches_tracked
module.main.up_1.res.res_2.conv1.weight
module.main.up_1.res.res_2.conv1.weight_u
module.main.up_1.res.res_2.bn1.running_mean
module.main.up_1.res.res_2.bn1.running_var
module.main.up_1.res.res_2.bn1.num_batches_tracked
module.main.up_1.res.res_2.conv2.weight
module.main.up_1.res.res_2.conv2.weight_u
module.main.up_1.res.res_2.bn2.running_mean
module.main.up_1.res.res_2.bn2.running_var
module.main.up_1.res.res_2.bn2.num_batches_tracked
module.main.up_1.conv.weight
module.main.up_1.conv.weight_u
module.main.up_1.bn.running_mean
module.main.up_1.bn.running_var
module.main.up_1.bn.num_batches_tracked
module.main.up_2.res.res_0.conv1.weight
module.main.up_2.res.res_0.conv1.weight_u
module.main.up_2.res.res_0.bn1.running_mean
module.main.up_2.res.res_0.bn1.running_var
module.main.up_2.res.res_0.bn1.num_batches_tracked
module.main.up_2.res.res_0.conv2.weight
module.main.up_2.res.res_0.conv2.weight_u
module.main.up_2.res.res_0.bn2.running_mean
module.main.up_2.res.res_0.bn2.running_var
module.main.up_2.res.res_0.bn2.num_batches_tracked
module.main.up_2.res.res_1.conv1.weight
module.main.up_2.res.res_1.conv1.weight_u
module.main.up_2.res.res_1.bn1.running_mean
module.main.up_2.res.res_1.bn1.running_var
module.main.up_2.res.res_1.bn1.num_batches_tracked
module.main.up_2.res.res_1.conv2.weight
module.main.up_2.res.res_1.conv2.weight_u
module.main.up_2.res.res_1.bn2.running_mean
module.main.up_2.res.res_1.bn2.running_var
module.main.up_2.res.res_1.bn2.num_batches_tracked
module.main.up_2.res.res_2.conv1.weight
module.main.up_2.res.res_2.conv1.weight_u
module.main.up_2.res.res_2.bn1.running_mean
module.main.up_2.res.res_2.bn1.running_var
module.main.up_2.res.res_2.bn1.num_batches_tracked
module.main.up_2.res.res_2.conv2.weight
module.main.up_2.res.res_2.conv2.weight_u
module.main.up_2.res.res_2.bn2.running_mean
module.main.up_2.res.res_2.bn2.running_var
module.main.up_2.res.res_2.bn2.num_batches_tracked
module.main.up_2.conv.weight
module.main.up_2.conv.weight_u
module.main.up_2.bn.running_mean
module.main.up_2.bn.running_var
module.main.up_2.bn.num_batches_tracked
module.main.up_3.res.res_0.conv1.weight
module.main.up_3.res.res_0.conv1.weight_u
module.main.up_3.res.res_0.bn1.running_mean
module.main.up_3.res.res_0.bn1.running_var
module.main.up_3.res.res_0.bn1.num_batches_tracked
module.main.up_3.res.res_0.conv2.weight
module.main.up_3.res.res_0.conv2.weight_u
module.main.up_3.res.res_0.bn2.running_mean
module.main.up_3.res.res_0.bn2.running_var
module.main.up_3.res.res_0.bn2.num_batches_tracked
module.main.up_3.res.res_1.conv1.weight
module.main.up_3.res.res_1.conv1.weight_u
module.main.up_3.res.res_1.bn1.running_mean
module.main.up_3.res.res_1.bn1.running_var
module.main.up_3.res.res_1.bn1.num_batches_tracked
module.main.up_3.res.res_1.conv2.weight
module.main.up_3.res.res_1.conv2.weight_u
module.main.up_3.res.res_1.bn2.running_mean
module.main.up_3.res.res_1.bn2.running_var
module.main.up_3.res.res_1.bn2.num_batches_tracked
module.main.up_3.res.res_2.conv1.weight
module.main.up_3.res.res_2.conv1.weight_u
module.main.up_3.res.res_2.bn1.running_mean
module.main.up_3.res.res_2.bn1.running_var
module.main.up_3.res.res_2.bn1.num_batches_tracked
module.main.up_3.res.res_2.conv2.weight
module.main.up_3.res.res_2.conv2.weight_u
module.main.up_3.res.res_2.bn2.running_mean
module.main.up_3.res.res_2.bn2.running_var
module.main.up_3.res.res_2.bn2.num_batches_tracked
module.main.up_3.conv.weight
module.main.up_3.conv.weight_u
module.main.up_3.bn.running_mean
module.main.up_3.bn.running_var
module.main.up_3.bn.num_batches_tracked
module.main.conv_final.weight
module.main.conv_final.weight_u
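
(Roughly, the diff above comes from snapshotting the state dict, running one forward in train mode without gradients, and comparing — a sketch:)

import copy
import torch

before = copy.deepcopy(netG.state_dict())
netG.train()
with torch.no_grad():
    netG(z, y)
after = netG.state_dict()

for k in before:
    if not torch.equal(before[k], after[k]):
        print(k)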

Comparison of eval and train:

import copy
import torch
import torch.nn as nn

netG = nn.DataParallel(Generator(isize, ysize, zsize, gen_up_blocks,
                                 gen_inp_planes, gen_resblocks, gen_filters,
                                 gen_self_attn))

checkpoint = torch.load('561-checkpoint-383.pth.tar', map_location=dev)
netG.load_state_dict(checkpoint['G_state_dict'])
sd1 = copy.deepcopy(netG.state_dict())

f1 = netG(z, y)  # forward in (default) train mode

checkpoint = torch.load('561-checkpoint-383.pth.tar', map_location=dev)
netG.load_state_dict(checkpoint['G_state_dict'])
netG.eval()
sd2 = copy.deepcopy(netG.state_dict())

f2 = netG(z, y)  # forward in eval mode

print(torch.all(torch.eq(f1, f2)))
for k, v in sd1.items():
    if not torch.all(torch.eq(sd2[k], v)):
        print(k)

Output:

tensor(0, dtype=torch.uint8)
module.main.conv_init.weight_orig
module.main.up_0.res.res_0.conv1.weight_orig
module.main.up_0.res.res_0.conv2.weight_orig
module.main.up_0.res.res_1.conv1.weight_orig
module.main.up_0.res.res_1.conv2.weight_orig
module.main.up_0.res.res_2.conv1.weight_orig
module.main.up_0.res.res_2.conv2.weight_orig
module.main.up_0.conv.weight_orig
module.main.sa_1.query_conv.weight_orig
module.main.sa_1.key_conv.weight_orig
module.main.sa_1.value_conv.weight_orig
module.main.up_1.res.res_0.conv1.weight_orig
module.main.up_1.res.res_0.conv2.weight_orig
module.main.up_1.res.res_1.conv1.weight_orig
module.main.up_1.res.res_1.conv2.weight_orig
module.main.up_1.res.res_2.conv1.weight_orig
module.main.up_1.res.res_2.conv2.weight_orig
module.main.up_1.conv.weight_orig
module.main.up_2.res.res_0.conv1.weight_orig
module.main.up_2.res.res_0.conv2.weight_orig
module.main.up_2.res.res_1.conv1.weight_orig
module.main.up_2.res.res_1.conv2.weight_orig
module.main.up_2.res.res_2.conv1.weight_orig
module.main.up_2.res.res_2.conv2.weight_orig
module.main.up_2.conv.weight_orig
module.main.up_3.res.res_0.conv1.weight_orig
module.main.up_3.res.res_0.conv2.weight_orig
module.main.up_3.res.res_1.conv1.weight_orig
module.main.up_3.res.res_1.conv2.weight_orig
module.main.up_3.res.res_2.conv1.weight_orig
module.main.up_3.res.res_2.conv2.weight_orig
module.main.up_3.conv.weight_orig
module.main.conv_final.weight_orig

By the way, spectral norm with DataParallel is broken on 0.4.1.


Is it fixed on master?

Also, does this mean I have to retrain my model, or can I just remove DataParallel in my eval script?
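
(For the mechanical part, I assume dropping DataParallel would just mean stripping the module. prefix when loading — something like this sketch:)

import torch

checkpoint = torch.load('561-checkpoint-383.pth.tar', map_location=dev)
state = {k.replace('module.', '', 1): v
         for k, v in checkpoint['G_state_dict'].items()}
netG = Generator(isize, ysize, zsize, gen_up_blocks, gen_inp_planes,
                 gen_resblocks, gen_filters, gen_self_attn)
netG.load_state_dict(state)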

Unfortunately, it is also broken in training mode, so your checkpoints won’t work… Sorry about that.

It’s not fixed on master, but I will submit a fix next week.


Thanks for letting me know!

Hi Simon, has the DataParallel problem with spectral norm been fixed?

A fix is at https://github.com/pytorch/pytorch/pull/12671
