The issue is that you are creating non-leaf variables by casting the tensors.
The reason the cast sometimes works is that it is a no-op, since the underlying data is already in the desired format. E.g.:
```python
weights1 = torch.tensor(np.random.randn(784, 128).astype(np.float32), requires_grad=True).float()
```

Here the call to `.float()` doesn't do anything, as the data is already of this type.
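A quick way to check which case you are in is the `is_leaf` attribute; a minimal sketch of the no-op case above:

```python
import numpy as np
import torch

# The data is already float32, so .float() returns the same tensor (a no-op)
weights1 = torch.tensor(
    np.random.randn(784, 128).astype(np.float32), requires_grad=True
).float()
print(weights1.is_leaf)  # True: no new tensor was created
print(weights1.grad_fn)  # None: no operation was recorded
```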
This line of code

```python
weights1 = torch.tensor(np.random.randn(784, 128), requires_grad=True).float()
```

creates a new tensor using the operation `float()`: `np.random.randn` returns float64 data, so the cast is a real copy rather than a no-op. `weights1` therefore refers to the result of that operation and is not a leaf variable anymore.
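If you do want a float32 leaf here, a minimal sketch of two common fixes (assuming float32 is indeed the dtype you want):

```python
import numpy as np
import torch

# The cast from float64 to float32 is recorded by autograd
weights1 = torch.tensor(np.random.randn(784, 128), requires_grad=True).float()
print(weights1.is_leaf)  # False: weights1 is the output of the cast

# Option 1: pass the target dtype directly, so no cast happens afterwards
fixed1 = torch.tensor(np.random.randn(784, 128), dtype=torch.float32, requires_grad=True)
print(fixed1.is_leaf)    # True

# Option 2: cast first, then enable gradient tracking on the result
fixed2 = torch.tensor(np.random.randn(784, 128)).float().requires_grad_()
print(fixed2.is_leaf)    # True
```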
Have a look at this excellent explanation for other use cases.