Does DataParallel apply attribute changes to its replicas?

Hello.
I have a question about DataParallel.

Assume I have a model that has its own attribute ‘mode’. The model’s forward flow can be changed by setting this ‘mode’.
For example:

import torch.nn as nn

class myModel(nn.Module):
    def __init__(self, mode='A'):
        super().__init__()  # must run before submodules are assigned
        # do essential initialization...
        self.mode = mode
        self.selectible_layers = nn.ModuleDict(
            {'A': moduleA(), 'B': moduleB(), ...}
        )

    def forward(self, x):
        # forwarding...
        x = self.selectible_layers[self.mode](x)  # pick the layer for the current mode
        # forwarding...
        return x

When I parallelize this model across multiple GPUs, I have to access the ‘mode’ attribute through model.module, because the nn.DataParallel wrapper itself doesn’t know about ‘mode’.

Then,

model = myModel(mode='A') # initialize with mode 'A'
model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])  # replicate across multiple GPUs

for mode in ['A', 'B', 'C', 'D']:
    model.module.mode = mode # change mode
    
    # (?)

At (?) in the code above, does nn.DataParallel guarantee that all model replicas on the GPUs have the same ‘mode’ after I change it?
I am worried that the replicas still have mode ‘A’ while the ‘mode’ of the original model (on the host GPU, probably GPU 0) changes.

+)
I tried to check this myself, but I don’t know how to access each replica on the GPUs. How can I access them?

DataParallel calls a replicate method on the module to create the replicas, so the attributes in __dict__ should be replicated as well. But because the flow is “DataParallel forward” -> “replicate models” -> “your model’s forward”, you need to make sure the mode is set properly before calling DataParallel forward.
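To make that flow concrete, here is a minimal pure-Python sketch of the idea. It is not PyTorch's actual code: TinyModule, make_replica and parallel_forward are hypothetical names, and the sketch only assumes that a replica is built by shallow-copying the original's __dict__ on every forward call.

```python
class TinyModule:
    """Hypothetical stand-in for a module with a plain Python attribute."""

    def __init__(self, mode="A"):
        self.mode = mode

    def forward(self, x):
        return f"{self.mode}:{x}"


def make_replica(module):
    # Create the object without running __init__, then shallow-copy __dict__,
    # so plain attributes such as `mode` carry over with their current values.
    replica = module.__class__.__new__(module.__class__)
    replica.__dict__ = module.__dict__.copy()
    return replica


def parallel_forward(module, inputs):
    # Replicas are rebuilt on every call, so they see whatever `mode`
    # was set on the original module before this call started.
    replicas = [make_replica(module) for _ in inputs]
    return [r.forward(x) for r, x in zip(replicas, inputs)]


model = TinyModule(mode="A")
first = parallel_forward(model, [0, 1])
model.mode = "B"  # change the attribute on the original only
second = parallel_forward(model, [0, 1])
print(first, second)
```

Under that assumption, a change made on the original before the call reaches every replica, while a change made inside a replica's forward would only touch that replica's copy.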

I tried to check this myself, but I don’t know how to access each replica on the GPUs. How can I access them?

The replicas are created on every forward pass of the DataParallel module. To access them and check, you can modify your model’s forward function to print out the mode value.
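A rough sketch of that check might look like the following. ModeProbe and its scale parameter are hypothetical names made up for illustration. On a machine with several GPUs, forward should print one line per replica; on a CPU-only or single-GPU machine, DataParallel calls the wrapped module directly, so only one line per call appears.

```python
import torch
import torch.nn as nn


class ModeProbe(nn.Module):
    """Hypothetical toy model that reports the mode it sees in forward."""

    def __init__(self, mode="A"):
        super().__init__()
        self.mode = mode
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Runs inside each replica, so this shows what every replica sees.
        print(f"device={x.device}, mode={self.mode}")
        return x * self.scale


use_cuda = torch.cuda.is_available()
model = ModeProbe(mode="A")
if use_cuda:
    model = model.cuda()  # parameters must live on the first device
wrapped = nn.DataParallel(model)

batch = torch.ones(8, 1, device="cuda" if use_cuda else "cpu")
for mode in ["A", "B"]:
    wrapped.module.mode = mode  # set before calling the wrapper's forward
    out = wrapped(batch)
```

If every printed line shows the mode that was just set, the replicas are picking up the change.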

Hello there!

Is it also possible to access the replicas directly somehow?
E.g., accessing the model.parameters() of each replica?