How to access a class object when I use torch.nn.DataParallel()?

I want to train my model using PyTorch with multiple GPUs. I included the following line:

model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

Then, I tried to access the optimizer that was defined in my model definition:

G_opt = model.module.optimizer_G

However, I got an error:

AttributeError: 'DataParallel' object has no attribute 'optimizer_G'

I think it is related to how the optimizer is defined in my model. Accessing it works when I use a single GPU without torch.nn.DataParallel, but with multiple GPUs it fails even though I go through .module, and I could not find a solution.

Here is the model definition:

class MyModel(torch.nn.Module):
    def __init__(self, opt):
        super().__init__()
        ...
        self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

I used the Pix2PixHD implementation from GitHub, if you want to see the full code.

Thank you. Best.

I do not know if you have tried this already, but since the error message tells you that the object is of the DataParallel class, can you check that model.module is actually the class you expect?

Something like:

print(type(model.module))
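For example (a minimal sketch, not your actual code), this check distinguishes a single wrap from an accidental double wrap:

import torch

net = torch.nn.Linear(2, 2)  # stand-in for MyModel

single = torch.nn.DataParallel(net)
print(type(single.module))  # <class 'torch.nn.modules.linear.Linear'>: .module is the real model

double = torch.nn.DataParallel(single)
print(type(double.module))  # <class 'torch.nn.parallel.data_parallel.DataParallel'>: wrapped twice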

Here is the output of print(type(model.module)):

<class 'torch.nn.parallel.data_parallel.DataParallel'>

Another point: in the original code, even without a torch.nn.DataParallel() call for multiple GPUs, the class attributes are still accessed through model.module, and that works. But when I include torch.nn.DataParallel(), it throws the error from my first post.

Edit: When I used model.module.module.optimizer_G, it worked.

This is what I suspected the answer would be.
You should carefully go through all of your code (or at least the parts where the model gets used), because if model.module is itself still a DataParallel, it means the model is already being wrapped in the DataParallel class somewhere earlier, and I don't know whether wrapping it twice causes bugs or unexpected behavior.
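One defensive pattern (a sketch, assuming opt.gpu_ids as in your first post) is to unwrap any existing DataParallel layers first and then wrap exactly once:

# Strip however many DataParallel wrappers are already there,
# then wrap a single time.
net = model
while isinstance(net, torch.nn.DataParallel):
    net = net.module
model = torch.nn.DataParallel(net, device_ids=opt.gpu_ids)

G_opt = model.module.optimizer_G  # a single .module is enough again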

Also, the recommended way is to use DistributedDataParallel whenever possible instead of DataParallel, since it creates one process per GPU. Here is a link to a tutorial that I found very helpful for learning how to use DistributedDataParallel.
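For reference, here is a minimal single-node sketch of the DistributedDataParallel setup. It reuses MyModel and opt from your post as placeholders; the spawn/init boilerplate is the standard pattern, not Pix2PixHD code:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    # One process per GPU: each rank owns exactly one device.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = MyModel(opt).to(rank)             # MyModel/opt as in the original post
    ddp_model = DDP(model, device_ids=[rank])

    # Custom attributes still live on the wrapped module:
    G_opt = ddp_model.module.optimizer_G

    # ... training loop ...

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size)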
I hope it helps!
