Hi,
The problem is that the original code here was computing wrong gradients.
You can fix this quite easily by overriding the forward function of nn.Linear so that it works on clones of the weight and bias:
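import torch.nn as nn
import torch.nn.functional as F

class MyLinear(nn.Linear):
    def forward(self, input):
        # Autograd saves the clones, not the parameters themselves
        bias = self.bias.clone() if self.bias is not None else None
        return F.linear(input, self.weight.clone(), bias)
# Use it in place of nn.Linear (MyLinear is your own class, not part of nn):
self.layer1 = MyLinear(10, 1)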
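Here is a minimal sketch of what the clone changes. It assumes the underlying issue is that the weight is modified in-place between the forward pass and the backward pass (for example an optimizer.step() between two backward calls), which is simulated here with a manual in-place update. The helper grad_after_inplace_update is hypothetical and only for illustration. With plain nn.Linear, autograd detects that a tensor it saved for backward was changed and raises a RuntimeError. With MyLinear, the saved tensor is the clone, so backward runs and uses the weight values from forward time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyLinear(nn.Linear):
    """nn.Linear whose forward uses clones of its parameters."""

    def forward(self, input):
        bias = self.bias.clone() if self.bias is not None else None
        return F.linear(input, self.weight.clone(), bias)


def grad_after_inplace_update(layer_cls):
    """Hypothetical helper: forward, modify the weight in-place, then backward.

    Returns (x.grad, weight at forward time); autograd errors propagate.
    """
    torch.manual_seed(0)
    layer = layer_cls(4, 1)
    # x requires grad, so the weight is needed (and saved) for backward
    x = torch.randn(8, 4, requires_grad=True)
    w_at_forward = layer.weight.detach().clone()
    out = layer(x)
    # Simulates e.g. an optimizer step happening before this backward call
    with torch.no_grad():
        layer.weight.mul_(0.5)
    out.sum().backward()
    return x.grad, w_at_forward


# Plain nn.Linear saved the weight itself, so the in-place update is detected
try:
    grad_after_inplace_update(nn.Linear)
    plain_failed = False
except RuntimeError:
    plain_failed = True
print("nn.Linear raised RuntimeError:", plain_failed)

# MyLinear saved a clone, so backward runs with the forward-time weight
x_grad, w_at_forward = grad_after_inplace_update(MyLinear)
# d(sum(x @ W.T + b))/dx has every row equal to W[0] (W as it was during forward)
print(torch.allclose(x_grad, w_at_forward.expand(8, 4)))
```

The trade-off is one extra copy of the weight and bias per forward call, which is usually negligible for a small layer but can add up for large ones.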