Optimizer inside model, good practice?


Today I saw, in this repository, the practice of putting the optimizer and the device inside the model:

```python
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

class GenericNetwork(nn.Module):
    def __init__(self, alpha, input_dims, fc1_dims, fc2_dims, n_actions):
        super(GenericNetwork, self).__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        self.fc3 = nn.Linear(self.fc2_dims, self.n_actions)
        # The optimizer is stored on the model itself
        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        # The device is hardcoded inside the model and the parameters are moved there
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cuda:1')
        self.to(self.device)

    def forward(self, observation):
        # Convert the raw observation and move it to the hardcoded device
        state = T.tensor(observation, dtype=T.float).to(self.device)
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
```

I felt that it was pretty ingenious, especially when you have several networks (it avoids having twice as many self attributes).

I was wondering whether this is good or bad practice and whether it has any drawbacks. I have never seen this before, so I am guessing there is a good reason for that?

I personally don’t use this workflow, as I don’t think that an optimizer is necessarily bound to a model.
Of course, your use case might use a single model and optimizer, in which case this workflow will certainly work, but to me it's "cleaner" to keep them as separate objects.
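To illustrate why an optimizer is not necessarily bound to one model, here is a minimal sketch assuming a hypothetical actor/critic pair (the names `actor` and `critic` and all sizes are illustrative only). A single optimizer updates both networks, each with its own learning rate via parameter groups, which would be awkward if every model owned its optimizer.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Two independent networks (hypothetical actor/critic pair, for illustration only)
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

# One optimizer spanning both models, with a separate learning rate per model
optimizer = optim.Adam([
    {"params": actor.parameters(), "lr": 1e-4},
    {"params": critic.parameters(), "lr": 1e-3},
])

# Keep a copy of the weights to check that the step updated both networks
actor_weight_before = actor.weight.detach().clone()
critic_weight_before = critic.weight.detach().clone()

state = torch.randn(8, 4)
loss = actor(state).pow(2).mean() + critic(state).pow(2).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

The same model could just as well be trained by two different optimizers (e.g. when switching from Adam to SGD mid-training), which is another case where storing the optimizer on the model does not fit naturally.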

Another potential drawback is checkpointing. model.state_dict() would still return all parameters and buffers, but not the optimizer state (the optimizer is not an nn.Module, so it is not part of the model's state_dict). You would therefore need to access the optimizer object through the model to save and load its state_dict. Again, this comes down to your coding style, so I'm sure there might also be valid reasons to use this approach.
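A minimal sketch of what checkpointing looks like when the optimizer lives inside the model, assuming a hypothetical `NetWithOptimizer` class modeled on the question's pattern. An in-memory `io.BytesIO` buffer stands in for a checkpoint file.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


class NetWithOptimizer(nn.Module):
    """Hypothetical minimal model that owns its optimizer, as in the question."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        return self.fc(x)


model = NetWithOptimizer()

# One training step so that the optimizer has internal state (Adam moments)
loss = model(torch.randn(8, 4)).pow(2).mean()
model.optimizer.zero_grad()
loss.backward()
model.optimizer.step()

# model.state_dict() only holds parameters and buffers, so the optimizer
# has to be reached through the model and saved as a separate entry
checkpoint = {
    "model": model.state_dict(),
    "optimizer": model.optimizer.state_dict(),
}
buffer = io.BytesIO()  # stands in for a checkpoint file path
torch.save(checkpoint, buffer)

# Restoring: build a fresh model (which builds a fresh optimizer), then load both parts
buffer.seek(0)
restored = torch.load(buffer)
new_model = NetWithOptimizer()
new_model.load_state_dict(restored["model"])
new_model.optimizer.load_state_dict(restored["optimizer"])
```

With separate objects the checkpoint dict looks the same, but it is built from `model.state_dict()` and `optimizer.state_dict()` directly, without reaching into the model.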

Also, I wouldn’t add the device information inside the model directly, as this would break e.g. nn.DataParallel.


Thanks for your answer! It does indeed seem inconvenient for saving/loading.
I'm not sure I understood the nn.DataParallel part; could you give a very small example or explanation?

The issue with nn.DataParallel is that the model hardcodes the device as cuda:0 if CUDA is available (and as cuda:1 otherwise, which will fail, so I assume 'cpu' should be used in the else branch).
However, in each forward pass nn.DataParallel copies the model to each specified device, splits the input along dim0, and sends each chunk to the corresponding device.
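As a rough mental model only (not DataParallel's actual implementation, which uses CUDA replicate/scatter/gather utilities), one forward pass with two devices can be simulated on the CPU. The sketch copies the module, splits the batch along dim0, runs each copy on its own chunk, and concatenates the results. The small `nn.Sequential` model is a hypothetical stand-in.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(8, 4)

# Conceptual, CPU-only sketch of one DataParallel forward pass with 2 "devices"
n_devices = 2
# 1) replicate: one copy of the module per device
replicas = [copy.deepcopy(model) for _ in range(n_devices)]
# 2) scatter: split the input batch along dim0, one chunk per replica
chunks = torch.chunk(x, n_devices, dim=0)
# 3) run each replica on its own chunk (on real GPUs these run in parallel)
outputs = [replica(chunk) for replica, chunk in zip(replicas, chunks)]
# 4) gather: concatenate the per-replica outputs along dim0
out = torch.cat(outputs, dim=0)
```

On real GPUs each replica and its chunk live on their own device, which is why nothing inside forward may pin tensors to one fixed device.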
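For this to work, you cannot use hardcoded devices inside the model or its forward method, since nn.DataParallel takes care of device placement. In the code above, self.device is copied into every replica as cuda:0, so the replica on cuda:1 moves its input chunk back to cuda:0 while its own parameters stay on cuda:1. If you need to create new tensors inside the forward, you have to use the device of the current model replica (e.g. via new_tensor = torch.randn(1, device=x.device)).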
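A minimal sketch of the question's network rewritten without a hardcoded device or an optimizer, assuming inputs already arrive as tensors. The `noise_std` argument and the noise term are hypothetical, included only to show a tensor created inside forward on the input's device.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeviceAgnosticNetwork(nn.Module):
    """Sketch of GenericNetwork without a stored device or optimizer."""

    def __init__(self, input_dims, fc1_dims, fc2_dims, n_actions):
        super().__init__()
        self.fc1 = nn.Linear(*input_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.fc3 = nn.Linear(fc2_dims, n_actions)

    def forward(self, x, noise_std=0.0):
        # x is already on this replica's device (DataParallel put it there),
        # so no .to(device) call is needed
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        # A tensor created inside forward (hypothetical exploration noise) takes
        # the device and dtype of the current activations, never a fixed device
        noise = torch.randn(logits.shape, device=logits.device, dtype=logits.dtype)
        return logits + noise_std * noise


model = DeviceAgnosticNetwork(input_dims=(4,), fc1_dims=16, fc2_dims=16, n_actions=2)
x = torch.randn(8, 4)
out = model(x)                   # no noise added
noisy = model(x, noise_std=0.1)  # noise created on x's device
```

The conversion of a NumPy observation that the original forward performed would then move to the calling code, e.g. `torch.as_tensor(observation, dtype=torch.float32, device=device)`.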

While this code works fine on the GPU:

```python
model = GenericNetwork(1., (1,), 1, 1, 1)
x = T.randn(8, 1)
out = model(x)
```

You'll get a device mismatch error if you wrap it with model = nn.DataParallel(model) on a machine with multiple GPUs:

```
RuntimeError: Expected tensor for 'out' to have the same device as tensor for argument #2 'mat1'; but device 1 does not equal 0 (while checking arguments for addmm)
```

for the aforementioned reason.
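A common alternative, sketched here assuming the model no longer stores a device or an optimizer (the small `nn.Sequential` model is a hypothetical stand-in): choose the device once in the training script, move the model there, wrap it in nn.DataParallel only when several GPUs are present, and then create the optimizer as a separate object. The torch.optim docs recommend moving a model to the GPU before constructing its optimizer.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Choose the device once, outside the model ('cpu' as the fallback, not 'cuda:1')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical small model; any nn.Module without a hardcoded device works the same way
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(), nn.Linear(1, 1)).to(device)

# Only wrap in DataParallel when several GPUs are actually available
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

# The optimizer is a separate object, created after the model is on its device
optimizer = optim.Adam(model.parameters(), lr=1.0)

x = torch.randn(8, 1, device=device)
out = model(x)
loss = out.pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because nothing in the model refers to a specific device, the same script runs unchanged on the CPU, on one GPU, or with nn.DataParallel on several GPUs.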
