Effect of calling model.cuda() after constructing an optimizer

According to the torch.optim docs, it is recommended to move a model to the GPU before constructing an optimizer, i.e.

model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.01)

# Do optimization

I only became aware of this recently, and have often constructed the optimizer before moving the model to the GPU, i.e.

optimizer = optim.SGD(model.parameters(), lr=0.01)

model.cuda()

# Do optimization

However, I’ve never noticed anything going “wrong” when doing this, and I have always gotten the results I expected in those situations.

So my question is: what, in practice, can go wrong when moving the model to the GPU after constructing an optimizer? And why is this behaviour not recommended?

Thank you!


It is fine in the case of SGD. However, if the optimizer constructs some buffer in __init__ based on the parameter type, then you will have a problem, e.g. https://github.com/pytorch/pytorch/blob/master/torch/optim/adagrad.py#L30
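
For reference, the relevant pattern looks roughly like this (a paraphrased sketch of the linked Adagrad constructor, not the exact source):

# Paraphrased sketch: Adagrad allocates its per-parameter 'sum' buffer in
# __init__, using whatever type/device p.data has at construction time.
def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0):
    defaults = dict(lr=lr, lr_decay=lr_decay, weight_decay=weight_decay)
    super(Adagrad, self).__init__(params, defaults)

    for group in self.param_groups:
        for p in group['params']:
            state = self.state[p]
            state['step'] = 0
            # If the model is moved to the GPU *after* this runs, 'sum'
            # stays a CPU tensor while p.data becomes a CUDA tensor.
            state['sum'] = p.data.new().resize_as_(p.data).zero_()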


Thanks for your response!
So, for example, if you use Adam or Adagrad or any of the algorithms that construct a buffer in __init__, what can go wrong in the optimization? Will the optimizer not keep track of this buffer appropriately?

It will error, e.g. when you try to accumulate the .grad into the buffer.
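
A minimal standalone illustration of that mismatch (on a CUDA-enabled build):

import torch

cpu_buf = torch.zeros(3)           # state buffer allocated while the params lived on the CPU
cuda_grad = torch.zeros(3).cuda()  # gradient computed after model.cuda()

# Mixing the two in an in-place op raises a device/type mismatch error
# (a TypeError in PyTorch of this era), analogous to what happens in step():
cpu_buf.add_(cuda_grad)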


I’ve tested this and it doesn’t seem to error on my side. For example, if we modify the MNIST PyTorch example by changing lines 70-75 from

model = Net()
if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

to

model = Net()

optimizer = optim.Adam(model.parameters(), lr=args.lr)

if args.cuda:
    model.cuda()

the code still runs fine (no error) and gives the expected results (e.g. 95% accuracy after 1 epoch). If you print optimizer.state, it returns something like this in both cases (i.e. whether the model was put on the GPU before or after creating the optimizer):

defaultdict(<type 'dict'>, {Parameter containing:
-0.0099
-0.0659
-0.4953
-0.0524
-0.2689
-0.1911
-0.0981
-0.1887
 0.1477
-0.1847
 0.0446
-0.0893
-0.2479
-0.0239
 0.0504
-0.1310
-0.1203
-0.1727
-0.1563
-0.0666
[torch.cuda.FloatTensor of size 20 (GPU 0)]
: {'exp_avg_sq': 
1.00000e-04 *
  7.2218
  0.0056
  1.4282
...
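
Rather than eyeballing the printout, a quick sanity check (a sketch, assuming the optimizer and torch import from the example above) is to verify that each state buffer lives on the same device as its parameter:

# Every tensor buffer in optimizer.state should match its parameter's device
for p, state in optimizer.state.items():
    for name, buf in state.items():
        if torch.is_tensor(buf):
            assert buf.is_cuda == p.data.is_cuda, (name, type(buf), type(p.data))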

To be clear, it’s awesome if this works in both cases. I just want to make sure that there is not some internal error happening that I’m unaware of and that could somehow mess up the results. I’ve written projects for work and for research in PyTorch where I used to put the model on the GPU after creating the optimizer (as I was unaware of this issue), and I want to make sure I don’t need to go back and redo all that work in case something is not working.
Thank you!


Can you try Adagrad? Adam doesn’t initialize such buffers: https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py
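
For comparison, Adam’s state is created lazily inside step(), roughly like this (a paraphrased sketch, not the exact source):

# Paraphrased sketch of torch/optim/adam.py: buffers are created on the first
# call to step(), so they inherit whatever device/type the parameters have by then.
for group in self.param_groups:
    for p in group['params']:
        if p.grad is None:
            continue
        state = self.state[p]
        if len(state) == 0:
            state['step'] = 0
            state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
            state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()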


Perfect, I’m able to produce that error with Adagrad (see below). Thank you for your help! As a last question, how come Adagrad doesn’t initialize its buffer in step, like Adam does, instead of in __init__? Wouldn’t that remove the problem altogether (assuming step isn’t called before moving the model to the GPU)?
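
For reference, the modification that triggers it is the same as above, with Adagrad in place of Adam:

model = Net()

optimizer = optim.Adagrad(model.parameters(), lr=args.lr)

if args.cuda:
    model.cuda()  # Adagrad's 'sum' buffers, created in __init__, stay on the CPU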

TypeError: addcmul_ received an invalid combination of arguments - got (int, torch.cuda.FloatTensor, torch.cuda.FloatTensor), but expected one of:
 * (torch.FloatTensor tensor1, torch.FloatTensor tensor2)
 * (torch.SparseFloatTensor tensor1, torch.SparseFloatTensor tensor2)
 * (float value, torch.FloatTensor tensor1, torch.FloatTensor tensor2)
      didn't match because some of the arguments have invalid types: (int, torch.cuda.FloatTensor, torch.cuda.FloatTensor)
 * (float value, torch.SparseFloatTensor tensor1, torch.SparseFloatTensor tensor2)
      didn't match because some of the arguments have invalid types: (int, torch.cuda.FloatTensor, torch.cuda.FloatTensor)

Yeah, it could be done that way. Feel free to submit a PR if you want to 🙂
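
For completeness, a hypothetical sketch of that change, mirroring Adam’s lazy initialization inside step():

# Hypothetical: allocate Adagrad's buffer on the first step() instead of in
# __init__, so it always matches the parameters' current device/type.
state = self.state[p]
if len(state) == 0:
    state['step'] = 0
    state['sum'] = p.data.new().resize_as_(p.data).zero_()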