Is making a wrapper around nn.DataParallel to access module attributes safe?

Hello everyone.
I am using PyTorch's nn.DataParallel, but with it I cannot access any of the attributes I defined on my model (I get an AttributeError).

So, I make a wrapper like this:

import torch.nn as nn

class DataParallelWrapper(nn.DataParallel):
    def __getattr__(self, name):
        try:
            # First try the DataParallel wrapper itself
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the wrapped model's own attributes
            return getattr(self.module, name)

Knowing next to nothing about how nn.DataParallel is implemented, I am left with a question: is this a safe approach?

If it matters, I am using PyTorch 1.1 (need to update code before upgrading to new shiny version).

DataParallel does have a self.module attribute:

The following code works for me:

import torch.nn as nn
import torch
from torch.nn.parallel import DataParallel

class DataParallelWrapper(DataParallel):
    def __init__(self, module):
        super().__init__(module)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


dpw = DataParallelWrapper(nn.Linear(2, 2))
print(getattr(dpw, "forward"))  # resolved on the DataParallel wrapper itself
print(getattr(dpw, "weight"))   # falls through to the wrapped nn.Linear

I am fully aware that I can access my network through nn.DataParallel.module. However, I need my code to run without an AttributeError when I access the attributes I defined on my network, whether the user is running on a single GPU or on multiple GPUs. Hence the DataParallelWrapper.
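For illustration, here is a minimal sketch of the usage pattern I have in mind (MyNet and its hidden_dim attribute are made-up names; the point is only that the same attribute access works whether or not the model ends up wrapped):

import torch
import torch.nn as nn


class MyNet(nn.Module):  # hypothetical model, just for illustration
    def __init__(self, hidden_dim=16):
        super().__init__()
        self.hidden_dim = hidden_dim                 # custom attribute I want to reach
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.fc(x)


model = MyNet()
if torch.cuda.device_count() > 1:
    model = DataParallelWrapper(model).cuda()        # multi-GPU: wrap the model
elif torch.cuda.is_available():
    model = model.cuda()                             # single GPU: use the module directly

print(model.hidden_dim)                              # no AttributeError in either case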

The only problem is that I do not know whether this is a good approach (could it introduce any data races or asynchrony problems?).

This should be fine, I think. The main logic of DataParallel is implemented in its forward() function: it replicates the given module to the available devices, scatters the input, launches one thread per device to process the scattered chunks, then joins the threads and gathers the outputs. So as long as you are not modifying module parameters or gradients concurrently while DataParallel.forward() is executing, I don’t see an issue here.
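As a rough illustration of that flow (a simplified sketch, not the actual DataParallel implementation, which also handles kwargs, dim, and various edge cases), using the helpers from torch.nn.parallel:

import torch
import torch.nn as nn
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def data_parallel_forward(module, x, device_ids, output_device=None):
    # Simplified sketch of what DataParallel.forward() does internally
    if output_device is None:
        output_device = device_ids[0]
    replicas = replicate(module, device_ids)       # copy the module onto each device
    inputs = scatter(x, device_ids)                # split the batch across devices
    replicas = replicas[:len(inputs)]              # batch may be smaller than the number of devices
    outputs = parallel_apply(replicas, inputs)     # one thread per replica
    return gather(outputs, output_device)          # collect the results on one device

if torch.cuda.device_count() > 1:
    net = nn.Linear(8, 4).cuda(0)
    x = torch.randn(32, 8).cuda(0)
    y = data_parallel_forward(net, x, list(range(torch.cuda.device_count())))
    print(y.shape)  # torch.Size([32, 4])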
