Add Gaussian noise to parameters while training

I tried to add Gaussian noise to the parameters using the code below, but the network won’t converge. Any thoughts on why? I used the CIFAR-10 dataset with lr=0.001.

import torch.nn as nn
import torch.nn.functional as F
import torch
__all__ = ['simplenet_cifar']


class Simplenet(nn.Module):
    def __init__(self):
        super(Simplenet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.relu_conv1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu_conv2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.relu_fc1 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu_fc2 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)

        self.noise_conv1 = torch.randn(nn.Parameter(self.conv1.weight).size())*0.6 + 0
        self.noise_conv2 = torch.randn(nn.Parameter(self.conv2.weight).size())*0.6 + 0

    def forward(self, x):
        x = self.conv1(x)
        self.conv1.weight = add_noise(nn.Parameter(self.conv1.weight), self.noise_conv1)
        x = self.pool1(self.relu_conv1(x))
        x = self.conv2(x)
        self.conv2.weight = add_noise(nn.Parameter(self.conv2.weight), self.noise_conv2)
        x = self.pool2(self.relu_conv2(x))
        # x = self.pool1(self.relu_conv1(self.conv1(x)))
        # x = self.pool2(self.relu_conv2(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.relu_fc1(self.fc1(x))
        x = self.relu_fc2(self.fc2(x))
        x = self.fc3(x)
        return x

def add_noise(weights, noise):
    with torch.no_grad():
        weight_noise = nn.Parameter(weights + noise.to("cuda"))
    return weight_noise

In your current code snippet you are recreating the .weight parameters as new nn.Parameters, which won’t be updated, as they are not passed to the optimizer.
You could add the noise in-place to the parameters instead, but you would also have to add it before these parameters are used (your code adds it only after self.conv1(x) and self.conv2(x) have already run).
This might work:

import torch
import torch.nn as nn


class Simplenet(nn.Module):
    def __init__(self):
        super(Simplenet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.relu_conv1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu_conv2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.relu_fc1 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu_fc2 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)

        # fixed noise tensors with the same shapes as the conv weights
        # (plain tensor attributes are not moved by model.to(device))
        self.noise_conv1 = torch.randn(self.conv1.weight.size())*0.6 + 0
        self.noise_conv2 = torch.randn(self.conv2.weight.size())*0.6 + 0

    def forward(self, x):
        # add the noise in-place *before* each conv layer uses its weights
        add_noise(self.conv1.weight, self.noise_conv1)
        x = self.conv1(x)
        x = self.pool1(self.relu_conv1(x))
        add_noise(self.conv2.weight, self.noise_conv2)
        x = self.conv2(x)
        x = self.pool2(self.relu_conv2(x))
        x = x.view(-1, 16 * 5 * 5)
        x = self.relu_fc1(self.fc1(x))
        x = self.relu_fc2(self.fc2(x))
        x = self.fc3(x)
        return x


def add_noise(weights, noise):
    # modify the existing parameter in-place, so the optimizer still tracks it
    with torch.no_grad():
        weights.add_(noise)


model = Simplenet()
x = torch.randn(1, 3, 32, 32)
out = model(x)
out.mean().backward()
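To illustrate why re-wrapping the weight in a new nn.Parameter (as the add_noise in the question does) stops training, here is a minimal sketch. It uses a small nn.Linear as an arbitrary stand-in for the conv layers and SGD as an arbitrary optimizer; the point is only that the optimizer keeps a reference to the *original* parameter object, so the replacement is never stepped.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)  # small stand-in for conv1/conv2
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

# Mimic the question's add_noise: wrap weight + noise in a brand-new Parameter
with torch.no_grad():
    layer.weight = nn.Parameter(layer.weight + 0.6 * torch.randn_like(layer.weight))

# The optimizer still holds the *old* weight object, not the new one
tracked = any(p is layer.weight
              for group in optimizer.param_groups for p in group["params"])
print("new weight tracked by optimizer:", tracked)

# One training step: the new weight receives a gradient but is never updated
before = layer.weight.detach().clone()
optimizer.zero_grad()
layer(torch.randn(3, 4)).sum().backward()
optimizer.step()
print("weight unchanged after step:", torch.equal(before, layer.weight.detach()))
```

Modifying the existing parameter in-place under torch.no_grad() avoids this, because the object the optimizer holds is the one being changed.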

Hi, thank you for answering!!! I tried your suggestion, but the network still doesn’t converge: the loss now becomes NaN after 2 epochs. Any thoughts on why?

I’m not familiar with your use case and don’t know why you are adding constant noise to the conv filters, but these noise tensors might just be too aggressive.
A simple overfitting test shows that the model itself is able to learn, but the noise seems to disrupt the training:

import torch
import torch.nn as nn


class Simplenet(nn.Module):
    def __init__(self):
        super(Simplenet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.relu_conv1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu_conv2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.relu_fc1 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu_fc2 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)

        # fixed noise tensors with the same shapes as the conv weights
        self.noise_conv1 = torch.randn(self.conv1.weight.size())*0.6 + 0
        self.noise_conv2 = torch.randn(self.conv2.weight.size())*0.6 + 0

    def forward(self, x):
        # add the noise in-place before each conv layer uses its weights
        add_noise(self.conv1.weight, self.noise_conv1)
        x = self.conv1(x)
        x = self.pool1(self.relu_conv1(x))
        add_noise(self.conv2.weight, self.noise_conv2)
        x = self.conv2(x)
        x = self.pool2(self.relu_conv2(x))
        x = x.view(-1, 16 * 5 * 5)
        x = self.relu_fc1(self.fc1(x))
        x = self.relu_fc2(self.fc2(x))
        x = self.fc3(x)
        return x


def add_noise(weights, noise):
    with torch.no_grad():
        weights.add_(noise)


# overfitting test: a single small random batch the model should be able to memorize
model = Simplenet()
data = torch.randn(8, 3, 32, 32)
target = torch.randint(0, 10, (8, ))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

nb_epochs = 100
for epoch in range(nb_epochs):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('epoch {}, loss {:.3f}'.format(epoch, loss.item()))

epoch 0, loss 11.253
epoch 1, loss 29.167
epoch 2, loss 66.412
epoch 3, loss 35.963
epoch 4, loss 73.107
epoch 5, loss 61.738
epoch 6, loss 57.288
epoch 7, loss 47.320
epoch 8, loss 65.957
epoch 9, loss 0.000
epoch 10, loss 28.883
epoch 11, loss 15.165
epoch 12, loss 31.132
epoch 13, loss 0.000
epoch 14, loss 52.350
epoch 15, loss 8.386
epoch 16, loss 0.004
epoch 17, loss 154.521
epoch 18, loss 0.000
epoch 19, loss 63.121
epoch 20, loss 0.000
epoch 21, loss 0.000
epoch 22, loss 93.031
epoch 23, loss 0.000
epoch 24, loss 23.481
epoch 25, loss 0.000
epoch 26, loss 147.731
epoch 27, loss 67.255
epoch 28, loss 0.000
epoch 29, loss 65.383
epoch 30, loss 0.000
epoch 31, loss 220.798
epoch 32, loss 0.000
epoch 33, loss 67.999
epoch 34, loss 0.000
epoch 35, loss 128.444
epoch 36, loss 0.000
epoch 37, loss 0.000
epoch 38, loss 0.000
epoch 39, loss 0.000
epoch 40, loss 0.000
epoch 41, loss 71.423
epoch 42, loss 0.000
epoch 43, loss 44.055
epoch 44, loss 186.880
epoch 45, loss 0.000
epoch 46, loss 0.000
epoch 47, loss 122.831
epoch 48, loss 97.819
epoch 49, loss 0.000
epoch 50, loss 0.000
epoch 51, loss 0.000
epoch 52, loss 0.000
epoch 53, loss 0.000
epoch 54, loss 278.068
epoch 55, loss 0.000
epoch 56, loss 0.000
epoch 57, loss 622.002
epoch 58, loss 0.000
epoch 59, loss 0.000
epoch 60, loss 0.000
epoch 61, loss 678.884
epoch 62, loss 0.000
epoch 63, loss 0.000
epoch 64, loss 11.639
epoch 65, loss 0.000
epoch 66, loss 449.992
epoch 67, loss 0.000
epoch 68, loss 0.000
epoch 69, loss 0.000
epoch 70, loss 0.000
epoch 71, loss 0.000
epoch 72, loss 131.694
epoch 73, loss 2245.142
epoch 74, loss 0.000
epoch 75, loss 0.000
epoch 76, loss 2373.450
epoch 77, loss 1397.521
epoch 78, loss 0.000
epoch 79, loss 273.361
epoch 80, loss 0.000
epoch 81, loss 492.052
epoch 82, loss 0.000
epoch 83, loss 109.878
epoch 84, loss 0.000
epoch 85, loss 0.000
epoch 86, loss 0.000
epoch 87, loss 110.755
epoch 88, loss 0.000
epoch 89, loss 2689.813
epoch 90, loss 0.000
epoch 91, loss 0.000
epoch 92, loss 0.000
epoch 93, loss 0.000
epoch 94, loss 0.000
epoch 95, loss 1625.105
epoch 96, loss 0.000
epoch 97, loss 0.000
epoch 98, loss 0.000
epoch 99, loss 0.000

If you remove the noise (or reduce it), the training behaves much better.
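One likely contributor, shown as a minimal sketch: because add_noise modifies the weights in-place on every forward pass and the noise tensor is created only once, the same noise is added again and again, so after k passes the weights have drifted by k times the noise. The sketch checks that drift on a single conv layer (no optimizer involved) and prints the scale of the default-initialized weights next to the scale of the noise; num_passes is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn


def add_noise(weights, noise):
    # Same in-place helper as in the snippet above
    with torch.no_grad():
        weights.add_(noise)


torch.manual_seed(0)
conv = nn.Conv2d(3, 6, 5)
noise = torch.randn(conv.weight.size()) * 0.6  # created once, as in __init__
w0 = conv.weight.detach().clone()

x = torch.randn(1, 3, 32, 32)
num_passes = 5
for _ in range(num_passes):
    add_noise(conv.weight, noise)  # runs at the start of every forward pass
    conv(x)

drift = conv.weight.detach() - w0
print("drift == num_passes * noise:",
      torch.allclose(drift, num_passes * noise, atol=1e-5))
# Compare the scale of the default-initialized weights with the noise
print("weight std:", w0.std().item(), "noise std:", noise.std().item())
```

With Adam at lr=1e-3 each optimizer step moves a weight by roughly the learning rate, which is small compared with a noise stddev of 0.6 being added on every pass.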


Thanks for pointing out the problem!!! Actually, I am trying to inject faults into the weights to test the error tolerance of a model. I’m going to add noise according to the formula below, but I want to try adding simpler noise first:
[Screenshot of the noise formula from the paper]

The paper states that sigma can range from 0.6 to 2, so I assumed that was the range of the noise. I tried adding smaller noise, but the results after some epochs are still not promising. Do you think adding random noise directly in the forward pass would change anything in the results?

Injecting noise into the model might act as a regularizer, but note that your current noise is static and you would most likely want to resample it in each forward pass. I don’t know how large the stddev should be to work properly.
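A minimal sketch of resampling, assuming you want to perturb the conv weights as in the question: instead of modifying the stored weights in-place, a subclass of nn.Conv2d can draw fresh noise on every call and use the perturbed copy only for that call. The optimizer then keeps updating the clean weights, and gradients flow through weight + noise back to the parameter. NoisyConv2d and its sigma argument are hypothetical names, and the sketch assumes the default padding_mode='zeros'.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyConv2d(nn.Conv2d):
    """Conv2d that perturbs its weight with freshly sampled Gaussian noise on every call.

    `sigma` (the noise stddev) is a hypothetical knob; the stored weight is never
    modified, so the optimizer keeps updating the clean parameter.
    Assumes padding_mode='zeros' (the default).
    """

    def __init__(self, *args, sigma=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.sigma = sigma

    def forward(self, x):
        weight = self.weight
        if self.sigma > 0:
            # new noise each forward pass; randn_like matches the weight's device and dtype
            weight = weight + torch.randn_like(weight) * self.sigma
        # gradients flow through `weight + noise` back to self.weight
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# Usage: e.g. as a replacement for conv1 in Simplenet
conv1 = NoisyConv2d(3, 6, 5, sigma=0.1)
out = conv1(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 6, 28, 28])
```

Unlike dropout, this sketch keeps the noise active in eval mode as well, which may suit fault-tolerance tests; setting sigma=0.0 turns it off for a clean baseline.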


Dear Ptrblck,

Reducing the noise and resampling it in each forward pass seems to work. Thank you a lot for your help!!!

Hey, I also want to insert some Gaussian noise into my DCGAN, not into the Conv2d weights, but into the Conv2d outputs. Any tips on how I can do that?
I’ve tried the following:

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1, bias=False)
        self.LeakyRelu = nn.LeakyReLU(0.2, inplace=True)
        self.dropout = nn.Dropout(0.4, inplace=False)
        self.conv2 = nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(64 * 2)
        self.conv3 = nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(64 * 4)
        self.conv4 = nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False)
        self.batchnorm4 = nn.BatchNorm2d(64 * 8)
        self.conv5 = nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, input):
        x = self.conv1(input)
        x = torch.randn(x.size()).to(device) + x
        x = self.LeakyRelu(x)
        x = self.dropout(x)
        x = self.conv2(x)
        x = torch.randn(x.size()).to(device) + x
        x = self.batchnorm2(x)
        x = self.LeakyRelu(x)
        x = self.dropout(x)
        x = self.conv3(x)
        x = torch.randn(x.size()).to(device) + x
        x = self.batchnorm3(x)
        x = self.LeakyRelu(x)
        x = self.dropout(x)
        x = self.conv4(x)
        x = torch.randn(x.size()).to(device) + x
        x = self.batchnorm4(x)
        x = self.LeakyRelu(x)
        x = self.dropout
        x = self.conv5(x)
        x = self.sigmoid(x)

        return x

Which gave me this error:

TypeError: conv2d() received an invalid combination of arguments - got (Dropout, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Dropout, Parameter, NoneType, tuple, tuple, tuple, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Dropout, Parameter, NoneType, tuple, tuple, tuple, int)

I think the error comes from the line x = torch.randn(x.size()).to(device) + x, but I don’t know how to fix it. Maybe I should round its output, but then I wouldn’t be adding Gaussian noise anymore, right?

You have a small typo in the code here:

        x = self.LeakyRelu(x)
        x = self.dropout
        x = self.conv5(x)

You forgot to call self.dropout(x), so you are passing the Dropout module itself to self.conv5.
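With that typo fixed, adding torch.randn(x.size()).to(device) to the conv outputs works as intended. A minimal sketch of a slightly more reusable variant, assuming 64x64 RGB inputs as in a standard DCGAN: a small GaussianNoise module (a hypothetical helper, not a built-in PyTorch layer) adds noise only in training mode, and torch.randn_like creates it on the activation's device, so the global device variable is not needed. Note that the original code adds unit-variance noise; the sigma=0.1 default here is an arbitrary placeholder.

```python
import torch
import torch.nn as nn


class GaussianNoise(nn.Module):
    """Adds zero-mean Gaussian noise to activations, in training mode only (hypothetical helper)."""

    def __init__(self, sigma=0.1):
        super().__init__()
        self.sigma = sigma

    def forward(self, x):
        if self.training and self.sigma > 0:
            # randn_like matches x's shape, device and dtype
            return x + torch.randn_like(x) * self.sigma
        return x


class Discriminator(nn.Module):
    def __init__(self, sigma=0.1):
        super().__init__()

        def block(c_in, c_out, bn):
            # conv -> noise on the conv output -> [batchnorm] -> LeakyReLU -> dropout
            layers = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False), GaussianNoise(sigma)]
            if bn:
                layers.append(nn.BatchNorm2d(c_out))
            layers += [nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.4)]
            return layers

        self.main = nn.Sequential(
            *block(3, 64, bn=False),
            *block(64, 64 * 2, bn=True),
            *block(64 * 2, 64 * 4, bn=True),
            *block(64 * 4, 64 * 8, bn=True),
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x)


# Usage: a batch of 64x64 images gives one score per image
D = Discriminator(sigma=0.1)
scores = D(torch.randn(8, 3, 64, 64))
print(scores.shape)  # torch.Size([8, 1, 1, 1])
```

Calling D.eval() disables both the dropout and the noise, so evaluation is deterministic.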


@thnguyen996 In what paper did you find this formula?

I know something similar under the term “variational weight noise”, but I’m not really sure where that term comes from.

I also found this other related post: Backpropagating through noise

Can you explain what you mean by resampling it in each pass?

Instead of creating the noise once in __init__ and adding it to the parameters, I recommended recreating the noise in each forward pass, so that it is actually random instead of acting as a static bias.