Overriding Module.parameters() with DataParallel

Hey,

I have a network that overrides the parameters() function to only include trainable parameters. This worked well until I tried to run it with DataParallel. I’m guessing I wasn’t supposed to override it, because DataParallel does not work with my model. Here’s an example:

# Python 3.6
# PyTorch 0.4.1 installed via Anaconda
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.module_list = nn.ModuleList()
        self.module_list.append(nn.Linear(10, 10))
        self.module_list.append(nn.Linear(10, 10))
    
    def parameters(self, only_trainable=True):
        for param in self.module_list.parameters():
            if only_trainable and not param.requires_grad:
                continue
            yield param

net = nn.DataParallel(Net().cuda())
net(torch.rand(1, 10))

This throws a NotImplementedError. If I set requires_grad=False on one of the modules’ parameters, I instead get a KeyError in torch/nn/parallel/replicate.py.

The solution is easy: just rename the function to something like trainable_parameters().
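Roughly like this, i.e. just moving the filter into a separate method and leaving the inherited parameters() alone (a sketch):

def trainable_parameters(self):
    # Same filtering as the old override, but under a new name, so the
    # default nn.Module.parameters() still yields every parameter.
    for param in self.module_list.parameters():
        if param.requires_grad:
            yield param

net.trainable_parameters() can then be used wherever the filtered list is needed (e.g. when building the optimizer).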

However, I’m a bit curious: should parameters() never be overridden? It worked perfectly fine when running on a single GPU, so I guess the function is used internally in some other parts of PyTorch? Or did I just not use yield properly?

Thanks in advance

I’m not completely sure, but I think your parameters method, which only yields parameters that require gradients, will break replicate, since it won’t yield all parameters.
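For example, with your Net and a frozen first layer, your override yields fewer parameters than the base implementation would (just an illustration):

net = Net()
for p in net.module_list[0].parameters():
    p.requires_grad = False

# your override: only the second layer's weight and bias
print(len(list(net.parameters())))            # 2
# base nn.Module implementation: all parameters, which replicate would need
print(len(list(nn.Module.parameters(net))))   # 4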

I thought so, as I kept getting a KeyError in replicate.py when I had frozen a layer. What I don’t get is why it is throwing a NotImplementedError in module.py when I return all parameters (i.e. all parameters require gradients). That’s why I thought I was yielding parameters incorrectly somehow.

Could you post the error message and the function it was thrown from?

With this class

# Python 3.6
# PyTorch 0.4.1 installed via Anaconda
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.module_list = nn.ModuleList()
        self.module_list.append(nn.Linear(10, 10))
        self.module_list.append(nn.Linear(10, 10))
    
    def parameters(self, only_trainable=True):
        for param in self.module_list.parameters():
            if only_trainable and not param.requires_grad:
                continue
            yield param

Doing this

net = nn.DataParallel(Net().cuda())
net(torch.rand(1, 10))

Throws this error

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-1-7bea0d51ab29> in <module>()
     18 
     19 net = nn.DataParallel(Net().cuda())
---> 20 net(torch.rand(1, 10))

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    121             return self.module(*inputs[0], **kwargs[0])
    122         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 123         outputs = self.parallel_apply(replicas, inputs, kwargs)
    124         return self.gather(outputs, self.output_device)
    125 

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
    131 
    132     def parallel_apply(self, replicas, inputs, kwargs):
--> 133         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
    134 
    135     def gather(self, outputs, output_device):

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
     75         output = results[i]
     76         if isinstance(output, Exception):
---> 77             raise output
     78         outputs.append(output)
     79     return outputs

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py in _worker(i, module, input, kwargs, device)
     51                 if not isinstance(input, (list, tuple)):
     52                     input = (input,)
---> 53                 output = module(*input, **kwargs)
     54             with lock:
     55                 results[i] = output

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in forward(self, *input)
     81             registered hooks while the latter silently ignores them.
     82         """
---> 83         raise NotImplementedError
     84 
     85     def register_buffer(self, name, tensor):

NotImplementedError: 

And doing this

net = Net().cuda()
# Freeze first layers parameters, i.e. only second layer is trainable
for param in net.module_list[0].parameters():
    param.requires_grad = False
net = nn.DataParallel(net)
net(torch.rand(1, 10))

Throws this error

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-1-8d9b7061f99f> in <module>()
     21     param.requires_grad = False
     22 net = nn.DataParallel(net)
---> 23 net(torch.rand(1, 10))

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    120         if len(self.device_ids) == 1:
    121             return self.module(*inputs[0], **kwargs[0])
--> 122         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    123         outputs = self.parallel_apply(replicas, inputs, kwargs)
    124         return self.gather(outputs, self.output_device)

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in replicate(self, module, device_ids)
    125 
    126     def replicate(self, module, device_ids):
--> 127         return replicate(module, device_ids)
    128 
    129     def scatter(self, inputs, kwargs, device_ids):

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in replicate(network, devices, detach)
     50                     replica._parameters[key] = None
     51             else:
---> 52                 param_idx = param_indices[param]
     53                 for j in range(num_replicas):
     54                     replica = module_copies[j][i]

KeyError: Parameter containing:
tensor([[-0.3002,  0.2907,  0.1129,  0.2012,  0.3133, -0.1077,  0.0199, -0.0915,
         -0.1875, -0.0787],
        [-0.1535, -0.0093, -0.1195, -0.2870,  0.2770,  0.2447,  0.1371,  0.2554,
         -0.2400,  0.0050],
        [ 0.1053, -0.0462, -0.2816, -0.2469, -0.2198,  0.1078,  0.1210, -0.2257,
          0.2912,  0.0348],
        [-0.2850, -0.2684,  0.1115,  0.1451,  0.3048, -0.1432, -0.0334, -0.0985,
          0.0428, -0.1384],
        [-0.2661,  0.3154,  0.0290, -0.0202, -0.2558, -0.2669, -0.1606, -0.1784,
          0.0666,  0.1534],
        [ 0.1977,  0.0073, -0.0256,  0.1687,  0.2736, -0.2341, -0.0254, -0.1233,
         -0.1083,  0.1307],
        [-0.3091, -0.1185,  0.2292, -0.2904,  0.1551, -0.1073,  0.0901,  0.0815,
          0.0563, -0.1869],
        [ 0.1131,  0.1455, -0.1215, -0.2023, -0.1883, -0.1709, -0.0097,  0.2165,
         -0.1549,  0.0916],
        [-0.0114, -0.2245,  0.1819, -0.2465,  0.1708,  0.0840, -0.3031, -0.0886,
          0.2049,  0.1661],
        [-0.0540, -0.1216, -0.1092,  0.1388,  0.2321, -0.1198, -0.1509,  0.2244,
          0.0655,  0.2590]], device='cuda:0')

The parameter that produces the KeyError is net.module.module_list[0].weight. I’m guessing the set of parameters DataParallel walks when replicating differs from what my overridden parameters() returns, which is used when creating the param index dict, or something like that?
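This is roughly how I picture it (a sketch of my understanding of replicate.py, not the actual source):

# the param index dict is built from network.parameters(), i.e. my override,
# so the frozen parameters never make it into the dict
param_indices = {p: i for i, p in enumerate(net.module.parameters())}

# the replicas are then filled by walking every submodule's _parameters,
# which still contains the frozen first layer
for module in net.module.modules():
    for key, param in module._parameters.items():
        if param is not None:
            param_indices[param]  # KeyError for module_list[0].weight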

Your Net class doesn’t have the forward method implemented.
Could you add the method and try it again?

Was this code running on a single GPU?

Oh yeah, you’re right, I forgot to add the forward method to the example class! Disregard the first error then.

The second error is what was originally my problem. I have tried running it on two GTX 1080 Ti cards as well as on two different NVIDIA GPUs.
The code was originally running on a single GPU without using DataParallel. Using DataParallel with a single GPU has no effect; the code runs fine.

# Python 3.6
# PyTorch 0.4.1 installed via Anaconda
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.module_list = nn.ModuleList()
        self.module_list.append(nn.Linear(2, 2))
        self.module_list.append(nn.Linear(2, 2))
    
    def parameters(self, only_trainable=True):
        for param in self.module_list.parameters():
            if only_trainable and not param.requires_grad:
                continue
            yield param
    
    def forward(self, x):
        return x

net = Net().cuda()
for p in net.module_list[0].parameters():
    p.requires_grad = False
net = nn.DataParallel(net, [0, 1])
net(torch.rand(10, 2))

Produces the same error

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-3-dc513201071a> in <module>()
     24     p.requires_grad = False
     25 net = nn.DataParallel(net, [0, 1])
---> 26 net(torch.rand(10, 2))

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    120         if len(self.device_ids) == 1:
    121             return self.module(*inputs[0], **kwargs[0])
--> 122         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    123         outputs = self.parallel_apply(replicas, inputs, kwargs)
    124         return self.gather(outputs, self.output_device)

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in replicate(self, module, device_ids)
    125 
    126     def replicate(self, module, device_ids):
--> 127         return replicate(module, device_ids)
    128 
    129     def scatter(self, inputs, kwargs, device_ids):

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in replicate(network, devices, detach)
     50                     replica._parameters[key] = None
     51             else:
---> 52                 param_idx = param_indices[param]
     53                 for j in range(num_replicas):
     54                     replica = module_copies[j][i]

KeyError: Parameter containing:
tensor([[-0.4050,  0.0905],
        [-0.1446, -0.5699]], device='cuda:0')

I still think the error is related to the fact that you are not replicating all Parameters, so they are missing in the replicas.
If you only specify one GPU for DataParallel, the module will just be called directly, without replication (line of code).

Maybe I’m not understanding your use case, but currently only the parameters requiring gradients will be replicated, which would create incomplete models.
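For illustration, assuming at least one visible GPU:

# With a single device id, DataParallel.forward calls self.module directly,
# so replicate() is never reached and the overridden parameters() doesn't matter.
net = nn.DataParallel(Net().cuda(), device_ids=[0])
out = net(torch.rand(10, 2))  # runs fine even with a frozen layer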

There isn’t really any issue; I solved it by making only_trainable=False the default in my parameters() function, so it behaves exactly like the normal nn.Module.parameters() function. I was initially curious as to why it didn’t work before, so I created some example code, forgot to implement forward, and that made me think there was some other issue (until you pointed it out). So I got answers to all my questions, thanks for the help!
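For reference, this is roughly what the method looks like now (same body as before, just a different default):

def parameters(self, only_trainable=False):
    # The default now matches nn.Module.parameters() and yields everything,
    # so DataParallel's replicate() sees all parameters.
    for param in self.module_list.parameters():
        if only_trainable and not param.requires_grad:
            continue
        yield param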

Good to know and sorry for the confusion. :wink: