import numpy as np
import torch
from torch.autograd import Variable

def gaussian_noise(inputs, mean=0, stddev=0.01):
    # Move to CPU, add numpy-sampled Gaussian noise, then move back to the GPU.
    input = inputs.cpu()
    input_array = input.data.numpy()
    noise = np.random.normal(loc=mean, scale=stddev, size=np.shape(input_array))
    out = np.add(input_array, noise)
    output_tensor = torch.from_numpy(out)
    out_tensor = Variable(output_tensor)
    out = out_tensor.cuda()
    out = out.float()
    return out
Switching to .cpu() and doing the work in numpy seems to affect the network's training (it's not training). Is there another, simpler way to do this, without going to numpy and back (and especially without the round trip through .cpu())?
EDIT:
This seems to work:
def gaussian(ins, is_training, stddev=0.2):
    if is_training:
        return ins + Variable(torch.randn(ins.size()).cuda() * stddev)
    return ins
But I'm not sure if I can move the mean just by adding a real number.
Yes, you can move the mean by adding the mean to the output of the normal variable.
But a perhaps better way of doing it is to use the normal_ function, as follows:
def gaussian(ins, is_training, mean, stddev):
    if is_training:
        noise = Variable(ins.data.new(ins.size()).normal_(mean, stddev))
        return ins + noise
    return ins
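For reference, a minimal usage sketch of this helper (the tensor shape and hyperparameters here are made up for illustration):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(8, 32))                                # hypothetical activations
x_noisy = gaussian(x, is_training=True, mean=0.0, stddev=0.1)   # x plus sampled noise
x_clean = gaussian(x, is_training=False, mean=0.0, stddev=0.1)  # returned unchanged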
To speed things up, I was thinking it'd be nice if the variable had been pre-allocated onto the GPU, and instead we just fill in the values:
class DynamicGNoise(nn.Module):
    def __init__(self, shape, std=0.05):
        super().__init__()
        self.noise = Variable(torch.zeros(shape, shape).cuda())
        self.std = std

    def forward(self, x):
        if not self.training:
            return x
        self.noise.data.normal_(0, std=self.std)
        print(x.size(), self.noise.size())
        return x + self.noise
Surprisingly, this did not work:
RuntimeError: size '[48 x 48]' is invalid for input with 8110080 elements at /home/uapatira/Desktop/pytorch-master/torch/lib/TH/THStorage.c:59
Eh? The output of the print statement is:
torch.Size([5, 704, 48, 48]) torch.Size([48, 48])
From what I understand of PyTorch broadcasting semantics, this shouldn't be a problem. In fact, if I try torch.FloatTensor(5, 176, 96, 96) + torch.FloatTensor(96, 96), it works just fine, but when I call .backward(), it looks like the broadcasting breaks down?
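One workaround, if you are willing to give up the pre-allocated buffer, is to sample noise with exactly the input's shape so that no broadcasting is involved in either the forward or the backward pass. This is only a sketch and assumes a PyTorch version that provides torch.empty_like:

import torch
import torch.nn as nn

class DynamicGNoise(nn.Module):
    """Adds zero-mean Gaussian noise with standard deviation `std` during training only."""
    def __init__(self, std=0.05):
        super().__init__()
        self.std = std

    def forward(self, x):
        if not self.training:
            return x
        # The noise matches x's shape, dtype and device, so no broadcasting is needed.
        noise = torch.empty_like(x).normal_(0, self.std)
        return x + noise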
In case you don't know the shape a priori (or can't be bothered to type it), here is what I normally use:
import torch
import torch.nn as nn

class GaussianNoise(nn.Module):
    """Gaussian noise regularizer.

    Args:
        sigma (float, optional): relative standard deviation used to generate the
            noise. Relative means that it will be multiplied by the magnitude of
            the value you are adding the noise to. This means that sigma can be
            the same regardless of the scale of the vector.
        is_relative_detach (bool, optional): whether to detach the variable before
            computing the scale of the noise. If `False`, the scale of the noise
            won't be seen as a constant but as something to optimize: this will bias
            the network to generate vectors with smaller values.
    """

    def __init__(self, sigma=0.1, is_relative_detach=True):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        self.noise = torch.tensor(0).to(device)  # assumes a `device` is defined elsewhere

    def forward(self, x):
        if self.training and self.sigma != 0:
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
            x = x + sampled_noise
        return x
I wonder whether the addition of noise affects the backward pass through that layer. For example, for quantization layers people typically use a straight-through estimator (which, if I remember correctly, is basically just gradient clipping). Should the gradients be clipped somehow when passing through noise layers as well?
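As a small sanity check (a sketch; it assumes the noise is sampled outside the autograd graph, as in the snippets above), adding a noise tensor that autograd treats as a constant is a plain elementwise add, so the gradient with respect to the input passes through unchanged and no clipping is required:

import torch

x = torch.randn(4, 3, requires_grad=True)
noise = torch.randn_like(x) * 0.1                   # no requires_grad: a constant for autograd
y = (x + noise).sum()
y.backward()
print(torch.allclose(x.grad, torch.ones_like(x)))   # True: the gradient passes straight through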
Thanks for sharing! I just made a few improvements:
If you are using DataParallel, you may need to register the noise tensor as a buffer in the init function; otherwise it will not be copied to all devices.
I also fixed a bug by adding .float() in the third-to-last line.
import torch
import torch.nn as nn

class GaussianNoise(nn.Module):
    """Gaussian noise regularizer.

    Args:
        sigma (float, optional): relative standard deviation used to generate the
            noise. Relative means that it will be multiplied by the magnitude of
            the value you are adding the noise to. This means that sigma can be
            the same regardless of the scale of the vector.
        is_relative_detach (bool, optional): whether to detach the variable before
            computing the scale of the noise. If `False`, the scale of the noise
            won't be seen as a constant but as something to optimize: this will bias
            the network to generate vectors with smaller values.
    """

    def __init__(self, sigma=0.1, is_relative_detach=True):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        # Registering as a buffer ensures the tensor is moved/copied with the module
        # (e.g. by .to(device) or DataParallel).
        self.register_buffer('noise', torch.tensor(0))

    def forward(self, x):
        if self.training and self.sigma != 0:
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.expand(*x.size()).float().normal_() * scale
            x = x + sampled_noise
        return x
The posted code checks self.training, which is toggled by calling model.train()/model.eval(), so in eval() mode the module returns the input unchanged.
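For illustration, a minimal usage sketch (the layer sizes are made up) showing how train()/eval() toggles the behaviour:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), GaussianNoise(sigma=0.1))
x = torch.randn(2, 10)

model.train()
out_noisy = model(x)   # noise is added after the linear layer

model.eval()
out_clean = model(x)   # GaussianNoise returns its input unchanged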