I’m currently implementing SAGAN in PyTorch, which uses the new nn.utils.spectral_norm (and batchnorm) for normalization. The samples look good during training; however, when I load a snapshot and set the network to eval mode, I get complete garbage as output. If I don’t set eval mode, the first sample after loading is normal, but subsequent samples are distorted, even if I set requires_grad to False on all the parameters. Any ideas how I can fix this?
I’m curious about this phenomenon.
I don’t have a concrete idea, but I think BatchNorm might play a role because of the population (running) statistics it uses in eval mode.
I think there are several parts:
- I haven’t implemented SAGAN yet, but SNGAN doesn’t use eval mode at all when training the generator. I think that is mostly because of batch norm which behaves quite differently between training/eval.
- I would expect that for a fully trained network, additional updates of the spectral_norm don’t have much impact. You could check what the largest singular value of the weight with spectral normalization is. Also, doesn’t spectral normalization affect only the discriminator (at least in SNGAN)?
- I’m surprised that in training mode the behaviour of batch norm changes: While the statistics are updated, they are not used in training. The weight/bias parameters on the other hand should not be updated without calling optimizer steps.
- When you are just sampling, you don’t call the optimizer’s step, do you?
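The check on the spectrally normalized weight suggested above could look roughly like this. It is a minimal sketch on a stand-alone toy layer (the layer sizes and the helper name `top_singular_value` are made up for illustration), and it assumes a PyTorch version that provides `torch.linalg.svdvals`. If the power iteration has converged, the value should be close to 1; it cannot drop below 1, because the estimated norm never exceeds the true one.

```python
import torch
import torch.nn as nn


def top_singular_value(conv):
    """Largest singular value of a conv weight flattened to (out_channels, -1)."""
    w = conv.weight.detach()
    return torch.linalg.svdvals(w.reshape(w.size(0), -1)).max().item()


torch.manual_seed(0)
# Toy layer, not the generator from this thread.
conv = nn.utils.spectral_norm(nn.Conv2d(4, 8, 3))
x = torch.randn(2, 4, 16, 16)

conv.train()
with torch.no_grad():
    # Each training-mode forward pass runs one power-iteration step.
    for _ in range(20):
        conv(x)

sigma = top_singular_value(conv)
print(sigma)
```

For a real checkpoint you would call the helper on each spectrally normalized layer after one forward pass and look for values far from 1.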
You could experimentally try to get closer to the matter:
- If you just set the batch norms to eval, does it output garbage, too?
- Does generating samples involve spectral normalization?
- If you set everything to eval except the batch norms, does it work? (Then you don’t need to worry about spectral norm.)
Best regards
Thomas
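The two mixed-mode experiments suggested above can be set up with small helpers like the following. This is only a sketch: the helper names are hypothetical, and it relies on the private base class `nn.modules.batchnorm._BatchNorm` to recognize all batch norm variants.

```python
import torch.nn as nn

_BN = nn.modules.batchnorm._BatchNorm  # base class of BatchNorm1d/2d/3d


def set_batchnorm_eval(model):
    """Put only the batch norm layers into eval mode; leave the rest untouched."""
    for m in model.modules():
        if isinstance(m, _BN):
            m.eval()
    return model


def set_eval_except_batchnorm(model):
    """Put everything into eval mode, then switch batch norm back to train mode."""
    model.eval()
    for m in model.modules():
        if isinstance(m, _BN):
            m.train()
    return model


# Toy model for demonstration only.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.LeakyReLU(0.2))

set_batchnorm_eval(model.train())
bn_only_eval = (model[0].training, model[1].training)

set_eval_except_batchnorm(model)
all_but_bn_eval = (model[0].training, model[1].training)
```

Note that a later call to `model.train()` or `model.eval()` resets every submodule, so the helper has to be applied again afterwards.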
With BatchNorm in eval mode and everything else in training mode, the result is even weirder: The first run also still works, but instead of being distorted, subsequent samples are complete garbage.
Setting all the spectral norm to eval, however, produced garbage from the first sample.
I have spectral norm in both the discriminator and generator. I don’t even have an optimizer, as this just loads the weights trained separately and generates samples.
By the way, here are all relevant parts of my model:
import torch
import torch.nn as nn

ysize = 45
zsize = 100
isize = 128
batchsize = 16
gen_up_blocks = 4
gen_inp_planes = 32
gen_resblocks = 3
gen_filters = 256
gen_self_attn = [0, 1, 0, 0, 0]
class Generator(nn.Module):
    def __init__(self, isize, ysize, zsize, gen_up_blocks, gen_inp_planes,
                 gen_resblocks, gen_filters, gen_self_attn):
        super(Generator, self).__init__()
        # Each up block doubles the resolution, so the starting side length is
        # isize / 2**gen_up_blocks (equal to gen_up_blocks ** 2 only for 2 or 4 blocks).
        sidedim = isize // (2 ** gen_up_blocks)
        self.sidedim = sidedim
        self.latent = nn.Linear(zsize + ysize, (sidedim ** 2) * gen_inp_planes)
        main = nn.Sequential()
        main.add_module('pad_init', nn.ReflectionPad2d(1))
        main.add_module('conv_init', nn.utils.spectral_norm(nn.Conv2d(gen_inp_planes, gen_filters, 3)))
        if gen_self_attn[0]:
            main.add_module('sa_{}'.format(0), Self_Attn(gen_filters))
        for i in range(gen_up_blocks):
            mod = GenUpBlock(gen_filters, gen_filters // 2, gen_resblocks)
            main.add_module('up_{}'.format(i), mod)
            gen_filters //= 2
            if gen_self_attn[i + 1]:
                main.add_module('sa_{}'.format(i + 1), Self_Attn(gen_filters))
        main.add_module('padding', nn.ReflectionPad2d(3))
        main.add_module('conv_final', nn.utils.spectral_norm(nn.Conv2d(gen_filters, 3, 7)))
        main.add_module('sigmoid', nn.Sigmoid())
        self.gen_inp_planes = gen_inp_planes
        self.main = main

    def forward(self, z, y):
        # Concatenate noise and condition, project, and reshape to a feature map.
        z = self.latent(torch.cat((z, y), 1))
        z = z.view(-1, self.gen_inp_planes, self.sidedim, self.sidedim)
        z = self.main(z)
        return z
class ResBlock(nn.Module):
    def __init__(self, channels, bn):
        super(ResBlock, self).__init__()
        self.bn = bn
        self.pad1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(channels, channels, 3))
        if bn:
            self.bn1 = nn.BatchNorm2d(channels)
        self.pad2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(channels, channels, 3))
        if bn:
            self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.pad1(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)
        out = nn.functional.leaky_relu(out, negative_slope=0.2)
        out = self.pad2(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)
        out += residual
        out = nn.functional.leaky_relu(out, negative_slope=0.2)
        return out
class GenUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, resblocks):
        super(GenUpBlock, self).__init__()
        self.res = nn.Sequential()
        for i in range(resblocks):
            self.res.add_module('res_{}'.format(i), ResBlock(in_channels, True))
        self.pad = nn.ReflectionPad2d(1)
        # PixelShuffle(2) turns out_channels * 4 channels into out_channels at 2x resolution.
        self.conv = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels * 4, 3))
        self.shuf = nn.PixelShuffle(2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        skip = x
        x = self.res(x)
        x += skip
        x = self.pad(x)
        x = self.conv(x)
        x = self.shuf(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
# (Modified) From https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py
class Self_Attn(nn.Module):
    """ Self attention Layer"""
    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()
        self.kernel_size = 1
        self.query_conv = nn.utils.spectral_norm(nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // 8, kernel_size=self.kernel_size))
        self.key_conv = nn.utils.spectral_norm(nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // 8, kernel_size=self.kernel_size))
        self.value_conv = nn.utils.spectral_norm(nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=self.kernel_size))
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x W x H)
        returns :
            out : self attention value + input feature (B x C x W x H)
        """
        m_batchsize, C, width, height = x.size()
        px = x
        proj_query = self.query_conv(px).view(
            m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C//8, N = W*H
        proj_key = self.key_conv(px).view(
            m_batchsize, -1, width * height)  # B x C//8 x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # B x N x N, normalized over the last dim
        proj_value = self.value_conv(px).view(
            m_batchsize, -1, width * height)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        out = self.gamma * out + x
        return out
Hmm. One of these days I’ll have to plug it into my SNGAN workbook to reproduce the failure.
Is this with master or an older version? I think we improved spectral norm at least before 0.4.1, but possibly also after.
On master, I think we expect spectral norm to give the same weight in eval mode as the last time in non-eval mode. Would you be able to check that (by saving the weight)?
Best regards
Thomas
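One way to run the check asked for above is to save the normalized weight after a training-mode forward pass and compare it with the weight computed in eval mode. This is a sketch on a toy layer, not on the generator from this thread, and the result depends on the PyTorch version: the behaviour described for master is that both weights match, which may not hold on older releases.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy layer for illustration only.
conv = nn.utils.spectral_norm(nn.Conv2d(4, 8, 3))
x = torch.randn(2, 4, 16, 16)

with torch.no_grad():
    conv.train()
    conv(x)  # runs a power-iteration step and recomputes conv.weight
    w_train = conv.weight.detach().clone()

    conv.eval()
    conv(x)  # recomputes conv.weight from the stored u (and v) without iterating
    w_eval = conv.weight.detach().clone()

same = torch.allclose(w_train, w_eval, atol=1e-6)
print(same)
```

If `same` is False on your version, eval mode is not reusing the weight from the last training step, which would explain a difference between train and eval outputs.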
I’m using torch==0.4.1 from PyPI.
As expected, when in eval mode, none of the weights change (but the output is garbage). When in train mode with no_grad() and requires_grad_(False), these are the weights that change:
module.main.conv_init.weight
module.main.conv_init.weight_u
module.main.up_0.res.res_0.conv1.weight
module.main.up_0.res.res_0.conv1.weight_u
module.main.up_0.res.res_0.bn1.running_mean
module.main.up_0.res.res_0.bn1.running_var
module.main.up_0.res.res_0.bn1.num_batches_tracked
module.main.up_0.res.res_0.conv2.weight
module.main.up_0.res.res_0.conv2.weight_u
module.main.up_0.res.res_0.bn2.running_mean
module.main.up_0.res.res_0.bn2.running_var
module.main.up_0.res.res_0.bn2.num_batches_tracked
module.main.up_0.res.res_1.conv1.weight
module.main.up_0.res.res_1.conv1.weight_u
module.main.up_0.res.res_1.bn1.running_mean
module.main.up_0.res.res_1.bn1.running_var
module.main.up_0.res.res_1.bn1.num_batches_tracked
module.main.up_0.res.res_1.conv2.weight
module.main.up_0.res.res_1.conv2.weight_u
module.main.up_0.res.res_1.bn2.running_mean
module.main.up_0.res.res_1.bn2.running_var
module.main.up_0.res.res_1.bn2.num_batches_tracked
module.main.up_0.res.res_2.conv1.weight
module.main.up_0.res.res_2.conv1.weight_u
module.main.up_0.res.res_2.bn1.running_mean
module.main.up_0.res.res_2.bn1.running_var
module.main.up_0.res.res_2.bn1.num_batches_tracked
module.main.up_0.res.res_2.conv2.weight
module.main.up_0.res.res_2.conv2.weight_u
module.main.up_0.res.res_2.bn2.running_mean
module.main.up_0.res.res_2.bn2.running_var
module.main.up_0.res.res_2.bn2.num_batches_tracked
module.main.up_0.conv.weight
module.main.up_0.conv.weight_u
module.main.up_0.bn.running_mean
module.main.up_0.bn.running_var
module.main.up_0.bn.num_batches_tracked
module.main.sa_1.query_conv.weight
module.main.sa_1.query_conv.weight_u
module.main.sa_1.key_conv.weight
module.main.sa_1.key_conv.weight_u
module.main.sa_1.value_conv.weight
module.main.sa_1.value_conv.weight_u
module.main.up_1.res.res_0.conv1.weight
module.main.up_1.res.res_0.conv1.weight_u
module.main.up_1.res.res_0.bn1.running_mean
module.main.up_1.res.res_0.bn1.running_var
module.main.up_1.res.res_0.bn1.num_batches_tracked
module.main.up_1.res.res_0.conv2.weight
module.main.up_1.res.res_0.conv2.weight_u
module.main.up_1.res.res_0.bn2.running_mean
module.main.up_1.res.res_0.bn2.running_var
module.main.up_1.res.res_0.bn2.num_batches_tracked
module.main.up_1.res.res_1.conv1.weight
module.main.up_1.res.res_1.conv1.weight_u
module.main.up_1.res.res_1.bn1.running_mean
module.main.up_1.res.res_1.bn1.running_var
module.main.up_1.res.res_1.bn1.num_batches_tracked
module.main.up_1.res.res_1.conv2.weight
module.main.up_1.res.res_1.conv2.weight_u
module.main.up_1.res.res_1.bn2.running_mean
module.main.up_1.res.res_1.bn2.running_var
module.main.up_1.res.res_1.bn2.num_batches_tracked
module.main.up_1.res.res_2.conv1.weight
module.main.up_1.res.res_2.conv1.weight_u
module.main.up_1.res.res_2.bn1.running_mean
module.main.up_1.res.res_2.bn1.running_var
module.main.up_1.res.res_2.bn1.num_batches_tracked
module.main.up_1.res.res_2.conv2.weight
module.main.up_1.res.res_2.conv2.weight_u
module.main.up_1.res.res_2.bn2.running_mean
module.main.up_1.res.res_2.bn2.running_var
module.main.up_1.res.res_2.bn2.num_batches_tracked
module.main.up_1.conv.weight
module.main.up_1.conv.weight_u
module.main.up_1.bn.running_mean
module.main.up_1.bn.running_var
module.main.up_1.bn.num_batches_tracked
module.main.up_2.res.res_0.conv1.weight
module.main.up_2.res.res_0.conv1.weight_u
module.main.up_2.res.res_0.bn1.running_mean
module.main.up_2.res.res_0.bn1.running_var
module.main.up_2.res.res_0.bn1.num_batches_tracked
module.main.up_2.res.res_0.conv2.weight
module.main.up_2.res.res_0.conv2.weight_u
module.main.up_2.res.res_0.bn2.running_mean
module.main.up_2.res.res_0.bn2.running_var
module.main.up_2.res.res_0.bn2.num_batches_tracked
module.main.up_2.res.res_1.conv1.weight
module.main.up_2.res.res_1.conv1.weight_u
module.main.up_2.res.res_1.bn1.running_mean
module.main.up_2.res.res_1.bn1.running_var
module.main.up_2.res.res_1.bn1.num_batches_tracked
module.main.up_2.res.res_1.conv2.weight
module.main.up_2.res.res_1.conv2.weight_u
module.main.up_2.res.res_1.bn2.running_mean
module.main.up_2.res.res_1.bn2.running_var
module.main.up_2.res.res_1.bn2.num_batches_tracked
module.main.up_2.res.res_2.conv1.weight
module.main.up_2.res.res_2.conv1.weight_u
module.main.up_2.res.res_2.bn1.running_mean
module.main.up_2.res.res_2.bn1.running_var
module.main.up_2.res.res_2.bn1.num_batches_tracked
module.main.up_2.res.res_2.conv2.weight
module.main.up_2.res.res_2.conv2.weight_u
module.main.up_2.res.res_2.bn2.running_mean
module.main.up_2.res.res_2.bn2.running_var
module.main.up_2.res.res_2.bn2.num_batches_tracked
module.main.up_2.conv.weight
module.main.up_2.conv.weight_u
module.main.up_2.bn.running_mean
module.main.up_2.bn.running_var
module.main.up_2.bn.num_batches_tracked
module.main.up_3.res.res_0.conv1.weight
module.main.up_3.res.res_0.conv1.weight_u
module.main.up_3.res.res_0.bn1.running_mean
module.main.up_3.res.res_0.bn1.running_var
module.main.up_3.res.res_0.bn1.num_batches_tracked
module.main.up_3.res.res_0.conv2.weight
module.main.up_3.res.res_0.conv2.weight_u
module.main.up_3.res.res_0.bn2.running_mean
module.main.up_3.res.res_0.bn2.running_var
module.main.up_3.res.res_0.bn2.num_batches_tracked
module.main.up_3.res.res_1.conv1.weight
module.main.up_3.res.res_1.conv1.weight_u
module.main.up_3.res.res_1.bn1.running_mean
module.main.up_3.res.res_1.bn1.running_var
module.main.up_3.res.res_1.bn1.num_batches_tracked
module.main.up_3.res.res_1.conv2.weight
module.main.up_3.res.res_1.conv2.weight_u
module.main.up_3.res.res_1.bn2.running_mean
module.main.up_3.res.res_1.bn2.running_var
module.main.up_3.res.res_1.bn2.num_batches_tracked
module.main.up_3.res.res_2.conv1.weight
module.main.up_3.res.res_2.conv1.weight_u
module.main.up_3.res.res_2.bn1.running_mean
module.main.up_3.res.res_2.bn1.running_var
module.main.up_3.res.res_2.bn1.num_batches_tracked
module.main.up_3.res.res_2.conv2.weight
module.main.up_3.res.res_2.conv2.weight_u
module.main.up_3.res.res_2.bn2.running_mean
module.main.up_3.res.res_2.bn2.running_var
module.main.up_3.res.res_2.bn2.num_batches_tracked
module.main.up_3.conv.weight
module.main.up_3.conv.weight_u
module.main.up_3.bn.running_mean
module.main.up_3.bn.running_var
module.main.up_3.bn.num_batches_tracked
module.main.conv_final.weight
module.main.conv_final.weight_u
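The reason entries change even without gradients or an optimizer is that both layer types update buffers during a training-mode forward pass: spectral norm runs a power-iteration step and batch norm updates its running statistics. A minimal sketch that reproduces this on a toy network follows; the helper name `changed_keys` is made up, and the exact key names depend on the PyTorch version (for example, 0.4.1 also lists `weight`, while newer versions keep `weight_orig`, `weight_u` and `weight_v`).

```python
import copy

import torch
import torch.nn as nn


def changed_keys(before, after):
    """Names of state_dict entries whose values differ between two snapshots."""
    return [k for k, v in before.items() if not torch.equal(v, after[k])]


torch.manual_seed(0)
# Toy network: one spectrally normalized conv followed by batch norm.
net = nn.Sequential(
    nn.utils.spectral_norm(nn.Conv2d(3, 8, 3)),
    nn.BatchNorm2d(8),
)
for p in net.parameters():
    p.requires_grad_(False)

before = copy.deepcopy(net.state_dict())
net.train()
with torch.no_grad():
    net(torch.randn(4, 3, 16, 16))

changed = changed_keys(before, net.state_dict())
print(changed)
```

Neither `no_grad()` nor `requires_grad_(False)` stops these updates, since they are buffer writes rather than gradient steps; only eval mode does.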
Comparison of eval and train:
import copy

# z, y and dev are defined elsewhere in the script.
netG = nn.DataParallel(Generator(isize, ysize, zsize, gen_up_blocks, gen_inp_planes, gen_resblocks, gen_filters, gen_self_attn))

# First run: train mode (the default after construction).
checkpoint = torch.load('561-checkpoint-383.pth.tar', map_location=dev)
netG.load_state_dict(checkpoint['G_state_dict'])
sd1 = copy.deepcopy(netG.state_dict())
f1 = netG(z, y)

# Second run: same checkpoint, eval mode.
checkpoint = torch.load('561-checkpoint-383.pth.tar', map_location=dev)
netG.load_state_dict(checkpoint['G_state_dict'])
netG.eval()
sd2 = copy.deepcopy(netG.state_dict())
f2 = netG(z, y)

print(torch.all(torch.eq(f1, f2)))
# Print every state_dict entry that differs between the two snapshots.
for k, v in sd1.items():
    if not torch.all(torch.eq(sd2[k], v)):
        print(k)
Output:
tensor(0, dtype=torch.uint8)
module.main.conv_init.weight_orig
module.main.up_0.res.res_0.conv1.weight_orig
module.main.up_0.res.res_0.conv2.weight_orig
module.main.up_0.res.res_1.conv1.weight_orig
module.main.up_0.res.res_1.conv2.weight_orig
module.main.up_0.res.res_2.conv1.weight_orig
module.main.up_0.res.res_2.conv2.weight_orig
module.main.up_0.conv.weight_orig
module.main.sa_1.query_conv.weight_orig
module.main.sa_1.key_conv.weight_orig
module.main.sa_1.value_conv.weight_orig
module.main.up_1.res.res_0.conv1.weight_orig
module.main.up_1.res.res_0.conv2.weight_orig
module.main.up_1.res.res_1.conv1.weight_orig
module.main.up_1.res.res_1.conv2.weight_orig
module.main.up_1.res.res_2.conv1.weight_orig
module.main.up_1.res.res_2.conv2.weight_orig
module.main.up_1.conv.weight_orig
module.main.up_2.res.res_0.conv1.weight_orig
module.main.up_2.res.res_0.conv2.weight_orig
module.main.up_2.res.res_1.conv1.weight_orig
module.main.up_2.res.res_1.conv2.weight_orig
module.main.up_2.res.res_2.conv1.weight_orig
module.main.up_2.res.res_2.conv2.weight_orig
module.main.up_2.conv.weight_orig
module.main.up_3.res.res_0.conv1.weight_orig
module.main.up_3.res.res_0.conv2.weight_orig
module.main.up_3.res.res_1.conv1.weight_orig
module.main.up_3.res.res_1.conv2.weight_orig
module.main.up_3.res.res_2.conv1.weight_orig
module.main.up_3.res.res_2.conv2.weight_orig
module.main.up_3.conv.weight_orig
module.main.conv_final.weight_orig
By the way, spectral norm with DataParallel is broken on 0.4.1.
Is it fixed on master?
Also, does this mean I have to retrain my model, or can I just remove DataParallel in my eval script?
Unfortunately it is broken in training mode. So your checkpoints won’t work… Sorry about it.
It’s not fixed on master, but I will submit a fix next week.
Thanks for letting me know!
Hi Simon, has the DataParallel problem with spectral norm been fixed?