I’m training a GAN model, and want to apply some weight initialization to the conv layers and batchnorm layers in the Generator. However, I notice a clear difference when doing the following two things: 1. put the model to GPU, 2. apply the weight init. The order of doing these two things will affect the model output results (i.e. generate different images from Generator G). Please check the attached code.
If you run the code, you will see 5 rows of images.
1: G is on cpu, apply weight init, let G generate a row of images.
2: G is on cpu, apply weight init, move to gpu, generate a row of images.
3: G is on cpu, move to gpu, generate a row of images.
4: G is on cpu, move to gpu, apply weight init, generate a row of images.
5. G is on cpu, move to gpu, generate a row of images.
Since I reset the random seed every time, so I think each time G should have the same weight and is initialized in the same way. However, that is not the case. Row 1 and row 2 have the same result, which is consistent as expected. However, rows 3,4,5 are different from 1,2, which is not expected.
I then have some questions and assumption:
- Comparing row (1,2) and row 4, the weights_init want to set the weight to have std of 0.02, but G1’s weight has std of 0.0099, which means the weights_init is not properly applied, but on row 4 is properly set to 0.02.
- Comparing row 4 and row 5, G4 applied the weight init and G5 is not, then why they generate the same images? I also print the weights, it shows they indeed have different weights, then why they generate the same image?
- Since G4 and G1 both applied weight init, shouldn’t they look the same? I assume it is due to the difference in how random numbers are generated (one on cpu and one on gpu). Is it the reason for this one?
My main conclusions are:
- If apply weight_init for G on cpu, the weight is indeed changed, but not as expected (0.02 vs 0.0099)
- If apply weight_init for G after G is on gpu, then this won’t change anything, although the model weights are indeed changed.
import torch
from torchvision.utils import save_image
from models import Generator, weights_init
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class GLU(nn.Module):
def __init__(self):
super(GLU, self).__init__()
def forward(self, x):
nc = x.size(1)
assert nc % 2 == 0, 'channels dont divide 2!'
nc = int(nc/2)
return x[:, :nc] * torch.sigmoid(x[:, nc:])
def conv1x1(in_planes, out_planes, bias=False):
"1x1 convolution with padding"
return spectral_norm(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
padding=0, bias=bias))
def conv3x3(in_planes, out_planes):
"3x3 convolution with padding"
return spectral_norm(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
padding=1, bias=False))
# Upsale the spatial size by a factor of 2
def upBlock(in_planes, out_planes):
block = nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
conv3x3(in_planes, out_planes*2),
nn.BatchNorm2d(out_planes*2),
GLU())
return block
class Generator(nn.Module):
"""docstring for CAN_Generator"""
def __init__(self, ngf=64, nz=100, nc=3):
super(Generator, self).__init__()
self.init = nn.Sequential(
spectral_norm(nn.ConvTranspose2d(nz, ngf*32, 4, 1, 0)),
nn.BatchNorm2d(ngf * 32),
GLU())
self.main = nn.Sequential(
upBlock(ngf * 16, ngf * 8),
upBlock(ngf * 8, ngf * 4),
upBlock(ngf * 4, ngf * 2),
upBlock(ngf * 2, ngf * 2),
upBlock(ngf * 2, ngf * 1),
upBlock(ngf * 1, ngf // 2),
)
self.to_256 = nn.Sequential(
conv1x1(ngf // 2, 3),
nn.Tanh()
)
def forward(self, input):
feat_256 = self.init(input.view(input.size(0), -1, 1, 1)).view(input.size(0), -1, 4, 4)
feat_256 = self.main(feat_256)
img_256 = self.to_256(feat_256)
return img_256
device = torch.device('cuda:1')
ngf=64
nz=100
torch.manual_seed(8888)
fixed_noise = torch.FloatTensor(8, nz, 1, 1).normal_(0, 1).to(device)
torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG1 = Generator(ngf=ngf, nz=nz)
netG1.apply(weights_init)
gimg1 = netG1(fixed_noise.cpu())
print(netG1.main[0][1].weight.std(), netG1.main[0][1].weight.mean(), netG1.main[0][1].weight[0][0])
torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG2 = Generator(ngf=ngf, nz=nz)
netG2.apply(weights_init)
netG2.to(device)
gimg2 = netG2(fixed_noise)
torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG3 = Generator(ngf=ngf, nz=nz)
netG3.to(device)
gimg3 = netG3(fixed_noise)
torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG4 = Generator(ngf=ngf, nz=nz).to(device)
netG4.apply(weights_init)
print(netG4.main[0][1].weight.std(), netG4.main[0][1].weight.mean(), netG4.main[0][1].weight[0][0])
print(netG4.main[0][2].weight.std(), netG4.main[0][2].weight.mean())
gimg4 = netG4(fixed_noise)
torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG5 = Generator(ngf=ngf, nz=nz).to(device)
print(netG5.main[0][1].weight.std(), netG5.main[0][1].weight.mean())
print(netG5.main[0][2].weight.std(), netG5.main[0][2].weight.mean())
gimg5 = netG5(fixed_noise)
print( torch.nn.functional.mse_loss(gimg4, gimg5) )
save_image( torch.cat([gimg1.to(device), gimg2, gimg3, gimg4, gimg5]).add(1).mul(0.5),
'test.jpg', nrow=8)