You can remove all the uses of .data and replace them with:

@torch.no_grad()
def weights_init(m):
    # Your code
And yes, this will reinitialize all the weights with random values.
You might be interested in the torch.nn.init package, which gives you many common initialization methods.
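For instance, a minimal sketch with two of its helpers (the layer sizes here are just placeholders):

import torch
from torch import nn

layer = nn.Linear(128, 64)
nn.init.xavier_uniform_(layer.weight)  # Glorot/Xavier uniform initialization
nn.init.zeros_(layer.bias)             # fill the bias with zeros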
import numpy as np
import torch
from torch import nn

@torch.no_grad()
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.normal_(0.0, 0.02)
    elif classname.find('Linear') != -1:
        # scale by the fan-in (number of inputs) of the layer
        n = m.in_features
        y = 1.0 / np.sqrt(n)
        m.weight.uniform_(-y, y)
        m.bias.fill_(0)
    elif classname.find('BatchNorm') != -1:
        # the torch.nn.init helpers operate on the parameter tensors directly
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0)
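For reference, such a function is typically applied to every submodule with Module.apply; the model below is just a placeholder:

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
model.apply(weights_init)  # calls weights_init(m) on every submodule of model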
Is @torch.no_grad() different from torch.no_grad()?
The two are the same: one is a context manager, the other is a function decorator. Decorating a function is the same as running its whole body inside a with torch.no_grad(): block, so it is quite convenient.
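A minimal sketch of the equivalence (the function names are just placeholders):

import torch

@torch.no_grad()                 # decorator form
def init_decorated(t):
    t.uniform_(-0.1, 0.1)        # runs without autograd tracking

def init_with_block(t):          # equivalent context-manager form
    with torch.no_grad():
        t.uniform_(-0.1, 0.1)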
The problem is that .data can hide some errors and give you wrong gradients.
For example, this issue popped up today: https://github.com/pytorch/pytorch/issues/30073, which is caused by the use of .data in the old codebase.
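A toy illustration of that failure mode (not taken from the linked issue): mutating a tensor through .data is invisible to autograd, so backward silently uses the stale value, while the same mutation under torch.no_grad() bumps the version counter and is caught at backward time.

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x ** 2               # the backward of ** needs the current value of x
x.data.mul_(10)          # invisible to autograd: no error, no warning
y.sum().backward()
print(x.grad)            # tensor([20., 40.]) -- silently wrong, should be [2., 4.]

x2 = torch.tensor([1.0, 2.0], requires_grad=True)
y2 = x2 ** 2
with torch.no_grad():
    x2.mul_(10)          # tracked: the in-place op bumps the version counter
y2.sum().backward()      # raises a RuntimeError about an in-place modification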