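import torch
import torch.nn as nn
def weights_init(m):
    # Apply Xavier (Glorot) uniform initialization to every Conv2d weight, in place.
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight.data)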
Alternatively, you could wrap the initialization in a with torch.no_grad(): block and remove the .data access.
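A minimal sketch of that no_grad variant, applied to a whole model with nn.Module.apply. The small Sequential model is a hypothetical example used only for illustration; it assumes a recent PyTorch where the in-place initializer is named xavier_uniform_ (the version without the trailing underscore is deprecated).

```python
import torch
import torch.nn as nn


def weights_init(m):
    # Xavier-initialize Conv2d weights in place. No .data is needed because
    # the write happens inside a no_grad block, so autograd does not record it.
    if isinstance(m, nn.Conv2d):
        with torch.no_grad():
            nn.init.xavier_uniform_(m.weight)


# Hypothetical example model, used only to demonstrate model.apply().
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3),
)

# apply() calls weights_init recursively on every submodule (and on model itself).
model.apply(weights_init)
```

The weights remain leaf parameters with requires_grad=True, so training proceeds normally afterwards.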