Writing a simple Gaussian noise layer in PyTorch

I wrote a simple noise layer for my network.

import numpy as np
import torch
from torch.autograd import Variable

def gaussian_noise(inputs, mean=0, stddev=0.01):
    # move to the CPU and convert to a NumPy array
    input = inputs.cpu()
    input_array = input.data.numpy()

    # sample Gaussian noise with the same shape and add it
    noise = np.random.normal(loc=mean, scale=stddev, size=np.shape(input_array))
    out = np.add(input_array, noise)

    # convert back to a CUDA float Variable
    output_tensor = torch.from_numpy(out)
    out_tensor = Variable(output_tensor)
    out = out_tensor.cuda()
    out = out.float()
    return out

Switching to .cpu() and doing the work in NumPy seems to break training (the network isn't learning). Is there another, simpler way to do this without going to NumPy and back, and especially without the .cpu() call?

EDIT:

This seems to work:

def gaussian(ins, is_training, stddev=0.2):
    if is_training:
        return ins + Variable(torch.randn(ins.size()).cuda() * stddev)
    return ins

But I’m not sure if I can move the mean just by adding a real number.


Yes, you can move the mean by adding it to the output of the normal sample.
But a perhaps better way is to use the in-place normal_ function, which samples on the same type and device as the input:

def gaussian(ins, is_training, mean, stddev):
    if is_training:
        # ins.data.new(...) allocates a tensor with the same type/device as ins
        noise = Variable(ins.data.new(ins.size()).normal_(mean, stddev))
        return ins + noise
    return ins
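
On newer PyTorch versions (0.4+), where Variable is deprecated and tensors carry gradients directly, a minimal equivalent sketch would be:

def gaussian(ins, is_training, mean=0.0, stddev=0.1):
    if is_training:
        # empty_like allocates on ins' device/dtype; normal_ fills it in place
        return ins + torch.empty_like(ins).normal_(mean, stddev)
    return ins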

To speed things up, I was thinking it would be nice if the variable were pre-allocated on the GPU, so that we only fill in the values:

class DynamicGNoise(nn.Module):
    def __init__(self, shape, std=0.05):
        super().__init__()
        self.noise = Variable(torch.zeros(shape,shape).cuda())
        self.std   = std
        
    def forward(self, x):
        if not self.training: return x
        self.noise.data.normal_(0, std=self.std)
        
        print(x.size(), self.noise.size())
        return x + self.noise

Surprisingly, this did not work:

RuntimeError: size '[48 x 48]' is invalid for input with 8110080 elements at /home/uapatira/Desktop/pytorch-master/torch/lib/TH/THStorage.c:59

Eh? The output of the print statement is:

torch.Size([5, 704, 48, 48]) torch.Size([48, 48])

From what I understand of PyTorch broadcasting semantics, this shouldn't be a problem. In fact, if I try torch.FloatTensor(5, 176, 96, 96) + torch.FloatTensor(96, 96), it works just fine, but when I try to call .backward(), it looks like the broadcasting breaks down?

Solved: Fixed using .expand() on self.noise.
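
For reference, a sketch of what the fixed forward presumably looks like (expanding the pre-allocated (shape, shape) noise to the input's full size instead of relying on broadcasting):

def forward(self, x):
    if not self.training:
        return x
    self.noise.data.normal_(0, std=self.std)
    # expand_as adds the leading dimensions without copying the underlying data
    return x + self.noise.expand_as(x)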


In case you don't know the shape a priori (or don't want to type it out), here is what I normally use:

class GaussianNoise(nn.Module):
    """Gaussian noise regularizer.

    Args:
        sigma (float, optional): relative standard deviation used to generate the
            noise. Relative means that it will be multiplied by the magnitude of
            the value you are adding the noise to. This means that sigma can be
            the same regardless of the scale of the vector.
        is_relative_detach (bool, optional): whether to detach the variable before
            computing the scale of the noise. If `False` then the scale of the noise
            won't be seen as a constant but something to optimize: this will bias the
            network to generate vectors with smaller values.
    """

    def __init__(self, sigma=0.1, is_relative_detach=True):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        self.noise = torch.tensor(0).to(device)

    def forward(self, x):
        if self.training and self.sigma != 0:
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
            x = x + sampled_noise
        return x 
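
A possible usage sketch (the layer can sit anywhere in a model; the sizes here are made up):

model = nn.Sequential(
    nn.Linear(128, 64),
    GaussianNoise(sigma=0.1),  # only active while model.train() is set
    nn.ReLU(),
    nn.Linear(64, 10),
)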

I wonder if the addition of noise affects the backward pass through that layer. For example, for quantization layers people typically use straight-through estimator (which is basically just gradient clipping if I remember correctly). Should the gradients be clipped somehow when passing through noise layers as well?

I think yes. https://arxiv.org/pdf/1607.00133.pdf (DP-SGD) clips the gradients and then adds noise.
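
A hedged sketch of that recipe at the optimizer level (the max_norm and noise scale are made-up values, and this skips the per-sample clipping the paper actually performs):

loss.backward()
# clip the global gradient norm, then perturb the gradients before stepping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
with torch.no_grad():
    for p in model.parameters():
        if p.grad is not None:
            p.grad += 0.01 * torch.randn_like(p.grad)
optimizer.step()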


How are all of you dealing with the noise causing values to go over 1 or under 0? Isn't this a problem?

I have this:

class noiseLayer_normal(nn.Module):
    def __init__(self, noise_percentage):
        super(noiseLayer_normal, self).__init__()
        self.n_scale = noise_percentage

    def forward(self, x):
        if self.training:
            # sample noise on the same device as the input
            noise_tensor = torch.normal(0, 0.2, size=x.size()).to(x.device)
            x = x + noise_tensor * self.n_scale

            # clip the result back into [0, 1]
            mask_high = (x > 1.0)
            mask_neg = (x < 0.0)
            x[mask_high] = 1
            x[mask_neg] = 0

        return x

But I think all of these masks are slowing down my training. Why do you not include this?
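
For what it's worth, the two masks can be replaced by a single torch.clamp call, which performs the same saturation (a sketch, not from the original post):

x = x + noise_tensor * self.n_scale
x = torch.clamp(x, min=0.0, max=1.0)  # keep values in [0, 1]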


Thanks for sharing! I just made a few improvements:
If you are using DataParallel, you may need to register the noise tensor as a buffer in the init function, otherwise it will not be copied to all devices.
I also fixed a bug by adding .float() in the third-to-last line.

class GaussianNoise(nn.Module):
    """Gaussian noise regularizer.

    Args:
        sigma (float, optional): relative standard deviation used to generate the
            noise. Relative means that it will be multiplied by the magnitude of
            the value you are adding the noise to. This means that sigma can be
            the same regardless of the scale of the vector.
        is_relative_detach (bool, optional): whether to detach the variable before
            computing the scale of the noise. If `False` then the scale of the noise
            won't be seen as a constant but something to optimize: this will bias the
            network to generate vectors with smaller values.
    """
    def __init__(self, sigma=0.1, is_relative_detach=True):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        self.register_buffer('noise', torch.tensor(0))

    def forward(self, x):
        if self.training and self.sigma != 0:
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.expand(*x.size()).float().normal_() * scale
            x = x + sampled_noise
        return x 
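
As a quick sanity check (assuming a CUDA device is available), the registered buffer now follows the module when it is moved or replicated:

gn = GaussianNoise(sigma=0.1)
gn.to('cuda:0')
print(gn.noise.device)  # cuda:0 -- the buffer moved with the module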

You can try:

noise = torch.randn_like(x)
x = ((noise + x).detach() - x).detach() + x

What's the point of the detach operations on the second line?

It's the straight-through estimator (STE) trick: it makes the gradient pass through as if no noise had been added.
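
For comparison, the same detach pattern is what makes the straight-through estimator useful for genuinely non-differentiable operations such as rounding (a hedged sketch, not from this thread):

# forward pass uses round(x); backward pass sees the identity
x_q = x + (torch.round(x) - x).detach()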

Will model.eval() treat it correctly in the inference stage?

The posted code uses if self.training, which is toggled by calling model.train()/model.eval(), and will thus return the input directly in eval() mode.
