Using LBFGS, CUDA and "custom" parameters

Hi,
I ran into the following issue: I have a simple neural network class with one custom variable ‘a’:

import torch

class NeuralNetwork(torch.nn.Module):
    def __init__(self, a):
        super().__init__()
        self.a = a
        self.fc1 = torch.nn.Linear(3, 5)

    def forward(self, x):
        x = self.a * self.fc1(x)
        return x

I am trying to do the following: create an instance of NeuralNetwork, move it to CUDA, and train it with the LBFGS optimizer, as follows:

a = torch.tensor(1.0, requires_grad = True)

device = torch.device("cuda:0")

net = NeuralNetwork(a)
net.to(device)

optimizer = torch.optim.LBFGS([net.a] + list(net.parameters()))
# optimizer = torch.optim.Adam([net.a] + list(net.parameters()))

xIn = torch.tensor((1.0,2.0,3.0), device = device)

def closure():
    optimizer.zero_grad()
    xOut = net(xIn)
    loss = xOut.sum() + a
    loss.backward()
    return loss

optimizer.step(closure)

The variable ‘a’ is also a parameter that should be updated by the optimizer, which is why it is included in the parameter list when ‘optimizer’ is created. When running the whole code, the last line (optimizer.step(closure)) raises the following error:

RuntimeError: All input tensors must be on the same device. Received cpu and cuda:0

I thought this might happen because ‘a’ is for some reason not on the GPU: when I leave ‘a’ out of the optimizer parameters, the error disappears. But then again, when I swap LBFGS for Adam (in the code above, comment the LBFGS line and uncomment the Adam line) while keeping ‘a’ in the optimizer parameters, the error also disappears.
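One quick way to check this suspicion (a small addition, not part of my original code) would be to print the devices right after net.to(device):

print(net.fc1.weight.device)  # registered parameters get moved, expected cuda:0
print(net.a.device)           # the plain tensor attribute, expected cpu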
Does anyone have an idea what might be the problem here and how to solve it?

Thank you!
Chris

Yes, it is not auto-moved, because it is not considered part of the model: you assigned it as a plain object attribute (without the nn.Parameter wrapper), so net.to(device) ignores it. Adam probably avoids the error because its implementation is more tolerant of parameters living on different devices.
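For example, here is a minimal sketch of the nn.Parameter route, keeping your class otherwise unchanged: wrapping the tensor inside __init__ registers it on the module, so net.to(device) moves it along with the linear layer's weights.

class NeuralNetwork(torch.nn.Module):
    def __init__(self, a):
        super().__init__()
        # nn.Parameter registers ‘a’ on the module, so it gets moved by
        # .to(device) and shows up in net.parameters()
        self.a = torch.nn.Parameter(a)
        self.fc1 = torch.nn.Linear(3, 5)

    def forward(self, x):
        return self.a * self.fc1(x)

net = NeuralNetwork(torch.tensor(1.0)).to(device)
print(net.a.device)  # now cuda:0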

Hi, I now changed the definition of a to

a = torch.nn.Parameter(torch.tensor(1.0), requires_grad = True)

and now it seems to work. Thank you!
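Side note: since ‘a’ is now an nn.Parameter assigned to the module, net.parameters() already includes it, so it should be enough to pass just net.parameters() to the optimizer instead of adding net.a separately:

optimizer = torch.optim.LBFGS(net.parameters())  # net.a is already in here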