I’m new to PyTorch. I am trying to write a function that adds some arbitrary Gaussian noise to the weights during training. My code is like this:
for m in model.modules():
    if hasattr(m, 'weight'):
        m.weight.add_(np.random.normal(my_mean, my_std, m.shape) * noise_strength)
My question is about the shape of “m”: how can I create noise with the same shape? I’m confused about how to do this; any help would be appreciated.
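A minimal sketch of one way to do this, assuming the goal is to perturb every module weight in place. The noise should take its shape from m.weight, not from m (a module has no shape), and torch.randn_like creates it directly as a tensor with the weight's shape, dtype and device, so NumPy is not needed. my_mean, my_std and noise_strength come from the question; add_weight_noise is a hypothetical helper name and the values below are placeholders.

```python
import torch
import torch.nn as nn

def add_weight_noise(model, my_mean, my_std, noise_strength):
    """Add (N(my_mean, my_std^2) noise) * noise_strength to every module weight, in place."""
    with torch.no_grad():  # the update itself must not be recorded by autograd
        for m in model.modules():
            w = getattr(m, "weight", None)
            # Containers have no weight, and some modules set weight = None
            if isinstance(w, torch.Tensor):
                noise = torch.randn_like(w) * my_std + my_mean  # same shape as w
                w.add_(noise * noise_strength)

# Example with a stand-in model and placeholder values
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.PReLU(), nn.Linear(3, 2))
before = [p.clone() for p in model.parameters()]
add_weight_noise(model, my_mean=0.0, my_std=0.1, noise_strength=0.5)
```

Note that nn.PReLU also has a weight, so it is perturbed too, while biases are left alone. Calling the helper once at the start of each epoch, with that epoch's mean and std, matches the setup described in the question.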
Do you mean something like this?
Actually, my model is a triplet network like the one below:
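import torch.nn as nn
class EmbeddingNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnet = nn.Sequential(nn.Conv2d(1, 32, 5), nn.PReLU(),
                                     nn.Conv2d(32, 64, 5), nn.PReLU())  # later layers not shown in the post
        self.fc = nn.Sequential(nn.Linear(64 * 4 * 4, 256))  # later layers not shown in the post

    def forward(self, x):
        output = self.convnet(x)
        output = output.view(output.size(0), -1)  # flatten to (batch, features)
        output = self.fc(output)
        return output

    def get_embedding(self, x):
        return self.forward(x)

class TripletNet(nn.Module):
    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

embedding_net = EmbeddingNet()
model = TripletNet(embedding_net)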
So I define an arbitrary mean and std and use them to generate Gaussian noise. How can I add this noise to the weights of this network during training, in each epoch? In each epoch I define a new mean and std.
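One way to wire this into training, sketched under a few assumptions: the noise is added once at the start of each epoch, before that epoch's optimizer steps; the schedule producing a new my_mean and my_std per epoch is a hypothetical placeholder (here the std simply shrinks as 1/(epoch + 1)); and a tiny nn.Sequential with random data stands in for the TripletNet and its data loader so the sketch runs on its own. Unlike the loop in the question, this version perturbs every parameter, biases included.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins: in the thread these would be the TripletNet and its data loader.
model = nn.Sequential(nn.Linear(8, 16), nn.PReLU(), nn.Linear(16, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs, targets = torch.randn(32, 8), torch.randn(32, 2)

num_epochs = 5
noise_strength = 0.1
history = []  # (epoch, mean, std) actually used, kept for inspection

for epoch in range(num_epochs):
    # Hypothetical per-epoch schedule: a new mean and std every epoch.
    my_mean = 0.0
    my_std = 1.0 / (epoch + 1)
    history.append((epoch, my_mean, my_std))

    # Perturb all parameters once, outside autograd, with this epoch's noise.
    with torch.no_grad():
        for p in model.parameters():
            p.add_((torch.randn_like(p) * my_std + my_mean) * noise_strength)

    # Ordinary training step for this epoch (a real loop would iterate over batches).
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
```

Adding the noise inside torch.no_grad() keeps it out of the computation graph, so the optimizer only sees gradients of the loss with respect to the (already perturbed) weights.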
Do you mean sampling from a normal distribution with a new mean and std? Then we could use:
normal_dist = torch.distributions.Normal(loc=torch.tensor([0.]), scale=torch.tensor([1.0]))
Replace torch.tensor([0.]) with the mean value and torch.tensor([1.0]) with the std value.
One example would be:
x = nn.Linear(3, 3)
t = normal_dist.sample((x.weight.view(-1).size())).reshape(x.weight.size())
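The view/reshape round trip above is needed because loc and scale are one-element tensors (shape [1]), so every sample carries an extra trailing dimension. A small variant, assuming scalar mean and std values (placeholders below): build the distribution from 0-dimensional tensors, and sample(x.weight.shape) returns a tensor that already matches the weight.

```python
import torch
import torch.nn as nn

x = nn.Linear(3, 3)

my_mean, my_std = 0.5, 2.0  # placeholder values for the current epoch
normal_dist = torch.distributions.Normal(loc=torch.tensor(my_mean), scale=torch.tensor(my_std))

# With scalar loc/scale, sample(shape) returns a tensor of exactly that shape,
# so no view/reshape is needed.
t = normal_dist.sample(x.weight.shape)

with torch.no_grad():  # in-place update of a leaf parameter must bypass autograd
    x.weight.add_(t)
```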
First, get the parameters of your model as a single vector:
from torch.nn.utils import vector_to_parameters, parameters_to_vector
param_vector = parameters_to_vector(model.parameters())
Then sample Gaussian noise of the same size as this vector and add it:
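from torch.distributions import Normal
n_params = len(param_vector)
noise = Normal(0.0, 1.0).sample((n_params,))  # sample((n,)) replaces the older sample_n(n)
param_vector = param_vector + noise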
Finally, load the parameters back into your model with vector_to_parameters(param_vector, model.parameters()).
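Putting the three steps together gives a minimal sketch, assuming a stand-in nn.Sequential model and placeholder my_mean/my_std values. It samples with sample(shape) rather than the older sample_n, and wraps everything in torch.no_grad() so the flatten, add and write-back are not tracked by autograd.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.utils import parameters_to_vector, vector_to_parameters

model = nn.Sequential(nn.Linear(4, 3), nn.PReLU(), nn.Linear(3, 2))  # stand-in model
my_mean, my_std = 0.0, 0.1  # placeholder values

with torch.no_grad():
    # 1. Flatten all parameters into one 1-D tensor (a new tensor, not a view).
    param_vector = parameters_to_vector(model.parameters())
    original = param_vector.clone()

    # 2. Sample Gaussian noise with the same number of elements and add it.
    noise = Normal(my_mean, my_std).sample(param_vector.shape)
    param_vector = param_vector + noise

    # 3. Write the perturbed values back into the model's parameters.
    vector_to_parameters(param_vector, model.parameters())
```

This treats all parameters (weights and biases) the same way; the per-module loop from the question is the simpler choice if only weights should be perturbed.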
Hi Ravin Jain, thanks for your helpful comment. I have a question about your code:
x.weight.data.add_(t) or x.weight.add_(t): what is the difference between them, and which is correct?
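A small sketch of the difference, on a stand-in nn.Linear. x.weight is a leaf tensor that requires grad, so calling add_ on it directly while autograd is recording raises a RuntimeError. x.weight.data refers to the same storage but is hidden from autograd, so the update goes through, though it also skips autograd's correctness checks. Doing the in-place update inside torch.no_grad() is generally preferred over .data in current PyTorch code; either way the weight stays a leaf with requires_grad=True.

```python
import torch
import torch.nn as nn

x = nn.Linear(3, 3)
t = torch.randn_like(x.weight)  # plain tensor, does not require grad

# 1. x.weight.add_(t) while autograd is recording: x.weight is a leaf that
#    requires grad, so PyTorch refuses the in-place update.
try:
    x.weight.add_(t)
    raised = False
except RuntimeError:
    raised = True

# 2. x.weight.data.add_(t): works, because .data shares storage with the
#    weight but is not tracked by autograd.
before = x.weight.detach().clone()
x.weight.data.add_(t)

# 3. Usual modern idiom: the same in-place update inside torch.no_grad().
with torch.no_grad():
    x.weight.add_(t)
```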