nn.Module cannot properly apply a weight initialization

I’m training a GAN model, and want to apply some weight initialization to the conv layers and batchnorm layers in the Generator. However, I notice a clear difference when doing the following two things: 1. put the model to GPU, 2. apply the weight init. The order of doing these two things will affect the model output results (i.e. generate different images from Generator G). Please check the attached code.

If you run the code, you will see 5 rows of images.
1: G is on cpu, apply weight init, let G generate a row of images.
2: G is on cpu, apply weight init, move to gpu, generate a row of images.
3: G is on cpu, move to gpu, generate a row of images.
4: G is on cpu, move to gpu, apply weight init, generate a row of images.
5. G is on cpu, move to gpu, generate a row of images.

Since I reset the random seed every time, so I think each time G should have the same weight and is initialized in the same way. However, that is not the case. Row 1 and row 2 have the same result, which is consistent as expected. However, rows 3,4,5 are different from 1,2, which is not expected.

I then have some questions and assumption:

  1. Comparing row (1,2) and row 4, the weights_init want to set the weight to have std of 0.02, but G1’s weight has std of 0.0099, which means the weights_init is not properly applied, but on row 4 is properly set to 0.02.
  2. Comparing row 4 and row 5, G4 applied the weight init and G5 is not, then why they generate the same images? I also print the weights, it shows they indeed have different weights, then why they generate the same image?
  3. Since G4 and G1 both applied weight init, shouldn’t they look the same? I assume it is due to the difference in how random numbers are generated (one on cpu and one on gpu). Is it the reason for this one?

My main conclusions are:

  1. If apply weight_init for G on cpu, the weight is indeed changed, but not as expected (0.02 vs 0.0099)
  2. If apply weight_init for G after G is on gpu, then this won’t change anything, although the model weights are indeed changed.
import torch
from torchvision.utils import save_image
from models import Generator, weights_init
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class GLU(nn.Module):
    def __init__(self):
        super(GLU, self).__init__()

    def forward(self, x):
        nc = x.size(1)
        assert nc % 2 == 0, 'channels dont divide 2!'
        nc = int(nc/2)
        return x[:, :nc] * torch.sigmoid(x[:, nc:])


def conv1x1(in_planes, out_planes, bias=False):
    "1x1 convolution with padding"
    return spectral_norm(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                     padding=0, bias=bias))


def conv3x3(in_planes, out_planes):
    "3x3 convolution with padding"
    return spectral_norm(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
                     padding=1, bias=False))


# Upsale the spatial size by a factor of 2
def upBlock(in_planes, out_planes):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv3x3(in_planes, out_planes*2),
        nn.BatchNorm2d(out_planes*2),
        GLU())
    return block


class Generator(nn.Module):
    """docstring for CAN_Generator"""

    def __init__(self, ngf=64, nz=100, nc=3):
        super(Generator, self).__init__()

        self.init = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(nz, ngf*32, 4, 1, 0)),
            nn.BatchNorm2d(ngf * 32),
            GLU())
        
        self.main = nn.Sequential(
            upBlock(ngf * 16, ngf * 8),
            upBlock(ngf * 8, ngf * 4),
            upBlock(ngf * 4, ngf * 2),
            upBlock(ngf * 2, ngf * 2),
            upBlock(ngf * 2, ngf * 1),
            upBlock(ngf * 1, ngf // 2),
            )
        self.to_256 = nn.Sequential(
            conv1x1(ngf // 2, 3),
            nn.Tanh()
            )

    def forward(self, input):
        feat_256 = self.init(input.view(input.size(0), -1, 1, 1)).view(input.size(0), -1, 4, 4)
        feat_256 = self.main(feat_256)
        img_256 = self.to_256(feat_256)
        return img_256

device = torch.device('cuda:1')
ngf=64
nz=100

torch.manual_seed(8888)

fixed_noise = torch.FloatTensor(8, nz, 1, 1).normal_(0, 1).to(device)

torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG1 = Generator(ngf=ngf, nz=nz)
netG1.apply(weights_init)
gimg1 = netG1(fixed_noise.cpu())
print(netG1.main[0][1].weight.std(), netG1.main[0][1].weight.mean(), netG1.main[0][1].weight[0][0])


torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG2 = Generator(ngf=ngf, nz=nz)
netG2.apply(weights_init)
netG2.to(device)
gimg2 = netG2(fixed_noise)

torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG3 = Generator(ngf=ngf, nz=nz)
netG3.to(device)
gimg3 = netG3(fixed_noise)

torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG4 = Generator(ngf=ngf, nz=nz).to(device)
netG4.apply(weights_init)
print(netG4.main[0][1].weight.std(), netG4.main[0][1].weight.mean(), netG4.main[0][1].weight[0][0])
print(netG4.main[0][2].weight.std(), netG4.main[0][2].weight.mean())
gimg4 = netG4(fixed_noise)

torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG5 = Generator(ngf=ngf, nz=nz).to(device)
print(netG5.main[0][1].weight.std(), netG5.main[0][1].weight.mean())
print(netG5.main[0][2].weight.std(), netG5.main[0][2].weight.mean())

gimg5 = netG5(fixed_noise)
print( torch.nn.functional.mse_loss(gimg4, gimg5) )
save_image( torch.cat([gimg1.to(device), gimg2, gimg3, gimg4, gimg5]).add(1).mul(0.5), 
                    'test.jpg', nrow=8)
            

I think that might be expected.
In 1 and 2 you are applying the weight init on the CPU, while 3 and 5 are not using the weight init at all, so these results should differ. (Also 3 and 5 seem to be the same use case, no?)
In 4 you are applying the weight init on the GPU, which will use another pseudorandom number generator as seen here:


torch.manual_seed(2809)
print(torch.randn(1))
> tensor([-2.0748])

torch.manual_seed(2809)
print(torch.randn(1))
> tensor([-2.0748])

torch.manual_seed(2809)
print(torch.randn(1, device='cuda'))
> tensor([0.5603], device='cuda:0')

torch.manual_seed(2809)
print(torch.randn(1, device='cuda'))
> tensor([0.5603], device='cuda:0')

Thanks for the reply.
My new question will be, while 3 and 5 do not use weight_init and 4 used weight_init, why the generated images look the same (which means the weight_init is not worked when the model weights are on GPU).

Basically, please check my “question” and “conclusion” sections which are the problems. Appreciate your help!

I cannot reproduce it and get a std of the expected ~0.02 for all weight parameters after running:

netG1.apply(weights_init)
for name, param in netG1.named_parameters():
    print(name, param.std())

How are you checking, if the output changed? Note that a visual comparison might seem to look equal, but e.g. if your model is saturating after the initialization, the output might look equal but would be calculated in a different manner.

Thanks for the reply.
For the first one, I see the reason I got std of 0.0099 instead of 0.02, which is because I have spectral_nom on conv layer weights and I print the std after a forward pass (which triggers the spectral norm to modify the weight).

However, for the second point, which weight_init does not work on GPU, I still cannot reason it properly. Please check this new example, which I print out the MSE between generated images:

import torch
from torchvision.utils import save_image
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class GLU(nn.Module):
    def __init__(self):
        super(GLU, self).__init__()

    def forward(self, x):
        nc = x.size(1)
        assert nc % 2 == 0, 'channels dont divide 2!'
        nc = int(nc/2)
        return x[:, :nc] * torch.sigmoid(x[:, nc:])


def conv1x1(in_planes, out_planes, bias=False):
    "1x1 convolution with padding"
    return spectral_norm(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                     padding=0, bias=bias))


def conv3x3(in_planes, out_planes):
    "3x3 convolution with padding"
    return spectral_norm(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
                     padding=1, bias=False))


# Upsale the spatial size by a factor of 2
def upBlock(in_planes, out_planes):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv3x3(in_planes, out_planes*2),
        nn.BatchNorm2d(out_planes*2),
        GLU())
    return block


class Generator(nn.Module):
    """docstring for CAN_Generator"""

    def __init__(self, ngf=64, nz=100, nc=3):
        super(Generator, self).__init__()

        self.init = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(nz, ngf*32, 4, 1, 0)),
            nn.BatchNorm2d(ngf * 32),
            GLU())
        
        self.main = nn.Sequential(
            upBlock(ngf * 16, ngf * 8),
            upBlock(ngf * 8, ngf * 4),
            upBlock(ngf * 4, ngf * 2),
            upBlock(ngf * 2, ngf * 2),
            upBlock(ngf * 2, ngf * 1),
            upBlock(ngf * 1, ngf // 2),
            )
        self.to_256 = nn.Sequential(
            conv1x1(ngf // 2, 3),
            nn.Tanh()
            )

    def forward(self, input):
        feat_256 = self.init(input.view(input.size(0), -1, 1, 1)).view(input.size(0), -1, 4, 4)
        feat_256 = self.main(feat_256)
        img_256 = self.to_256(feat_256)
        return img_256

device = torch.device('cuda:0')
ngf=64
nz=100

torch.manual_seed(8888)

fixed_noise = torch.FloatTensor(8, nz, 1, 1).normal_(0, 1).to(device)

torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG1 = Generator(ngf=ngf, nz=nz)
netG1.apply(weights_init)
gimg1 = netG1(fixed_noise.cpu())

torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG2 = Generator(ngf=ngf, nz=nz)
gimg2 = netG2(fixed_noise.cpu())


print( torch.nn.functional.mse_loss(gimg1, gimg2) )
save_image( torch.cat([gimg1, gimg2]).add(1).mul(0.5), 'test_cpu.jpg', nrow=8)


torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG4 = Generator(ngf=ngf, nz=nz).to(device)
netG4.apply(weights_init)
gimg4 = netG4(fixed_noise)

torch.manual_seed(8888)
torch.cuda.manual_seed(8888)
netG5 = Generator(ngf=ngf, nz=nz).to(device)
gimg5 = netG5(fixed_noise)

print( torch.nn.functional.mse_loss(gimg4, gimg5) )
save_image( torch.cat([gimg4, gimg5]).add(1).mul(0.5), 'test_gpu.jpg', nrow=8)

For image 1 and 2, I run everything on CPU, image 1 is with model weight_init and image 2 is without init, the generated images looks clearly different, and their mse is big (0.37). In contrast, image 4 and 5 are for models on GPU, the generated images looks very same and their mse is only 0.001 (means almost the same).

Thanks for your time, please check my follow up. Basically about the issue that: there is a clear difference between applying weight_init on “model on CPU” and “model on GPU”.

Can anyone please check the problem?