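```python
import torch
import torch.nn as nn


def weights_init(m):
    # Re-initialize only Conv2d layers; other modules keep their defaults.
    if isinstance(m, nn.Conv2d):
        # xavier_uniform (no trailing underscore) is deprecated; use the in-place xavier_uniform_.
        nn.init.xavier_uniform_(m.weight.data)
```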
Alternatively, you could wrap the call in `with torch.no_grad():` and pass `m.weight` directly instead of `m.weight.data`.
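A minimal sketch of the `torch.no_grad()` variant, assuming a PyTorch release that provides the in-place `nn.init.xavier_uniform_`. The small `nn.Sequential` model is a hypothetical example used only to show how the function is applied via `Module.apply`.

```python
import torch
import torch.nn as nn


def weights_init(m):
    # Only touch Conv2d layers; other modules keep their default init.
    if isinstance(m, nn.Conv2d):
        # no_grad() keeps the in-place init out of autograd's graph,
        # so the Parameter itself can be passed instead of .data.
        with torch.no_grad():
            nn.init.xavier_uniform_(m.weight)


# Hypothetical example model, just to demonstrate usage.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3),
)

# Module.apply calls weights_init recursively on every submodule and on model itself.
model.apply(weights_init)
```

Because `apply` visits every submodule, the `isinstance` check is what restricts the re-initialization to convolution layers.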