Parameters initialisation

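import torch.nn as nn

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        # xavier_uniform_ (trailing underscore) is the in-place version; the old xavier_uniform is deprecated
        nn.init.xavier_uniform_(m.weight.data)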

Alternatively, you could wrap the call in a with torch.no_grad(): block and drop the .data access.
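Here is a minimal sketch of that alternative, together with one way to run the function over a whole model via apply. The small nn.Sequential model and its layer sizes are made up for illustration, and as far as I know the nn.init functions already disable gradient tracking internally, so the explicit no_grad block is mostly there to make the intent obvious.

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Only touch convolution layers; every other module is left as-is.
    if isinstance(m, nn.Conv2d):
        # no_grad keeps the in-place initialisation out of autograd,
        # so .data is not needed.
        with torch.no_grad():
            nn.init.xavier_uniform_(m.weight)

# Hypothetical toy model, just to show the call.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3),
)

# apply() calls weights_init on every submodule, then on the model itself.
model.apply(weights_init)
```

The weights remain ordinary parameters afterwards (requires_grad is still True), so training proceeds as usual.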
