After upgrading PyTorch from 1.4 to 1.5, how to call backward() twice

Hi,

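The problem is that the original code here was silently computing wrong gradients: the weights saved for the backward pass were modified in place (for example by optimizer.step()) before the second backward() call. Since 1.5 these in-place updates are tracked, so autograd raises an error instead of quietly using the new values.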

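You can work around this quite easily by overriding the forward function of nn.Linear so that it uses clones of the parameters. The graph then saves the clones, which the in-place update does not touch: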

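```python
import torch.nn as nn
import torch.nn.functional as F

class MyLinear(nn.Linear):
    def forward(self, input):
        # Autograd saves the clones, not the parameters themselves
        bias = self.bias.clone() if self.bias is not None else None
        return F.linear(input, self.weight.clone(), bias)

# And use this one later (e.g. in your module's __init__):
self.layer1 = MyLinear(10, 1)
```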
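A minimal sketch to check the workaround, assuming the original setup was an optimizer step between two backward() calls on the same graph (the original question is not shown here, so the `two_backwards` helper, the SGD optimizer, and the input shape are hypothetical). The input is created with `requires_grad=True` so that the linear layer's backward really needs the saved weight. With plain nn.Linear the second backward() should then raise the "modified by an inplace operation" RuntimeError on 1.5+, while MyLinear goes through.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyLinear(nn.Linear):
    def forward(self, input):
        # Autograd saves the clones, so in-place updates to the parameters
        # after this forward do not invalidate the graph
        bias = self.bias.clone() if self.bias is not None else None
        return F.linear(input, self.weight.clone(), bias)


def two_backwards(layer):
    """Hypothetical helper: one forward, then backward -> step -> backward."""
    opt = torch.optim.SGD(layer.parameters(), lr=0.1)
    x = torch.randn(4, 10, requires_grad=True)  # makes backward need the weight
    loss = layer(x).pow(2).sum()
    loss.backward(retain_graph=True)  # keep the graph for a second pass
    opt.step()                        # updates weight and bias in place
    opt.zero_grad()
    loss.backward()                   # second backward through the same graph
    return layer.weight.grad


try:
    two_backwards(nn.Linear(10, 1))
    plain_ok = True
except RuntimeError:
    plain_ok = False  # saved weight was modified by an in-place operation

grad = two_backwards(MyLinear(10, 1))
print("plain nn.Linear ok:", plain_ok)
print("MyLinear grad shape:", tuple(grad.shape))
```

Note that the second backward() still computes gradients with respect to the weights as they were at forward time, since those are the values the clones captured. If you want gradients at the updated weights, run a new forward pass instead.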