How can I let one module bypass the DDP wrapper?


My model is defined in the classical way, but it contains a self-defined operator:

import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        # self_defined is the custom operator, defined elsewhere
        self.op = self_defined()

I am using DDP to train this model. DDP all-reduces the gradients across GPUs (summing them and dividing by the number of processes) to get the global gradient that is used to update the parameters.

model = M()
# torch.distributed.init_process_group() must be called before wrapping
model = nn.parallel.DistributedDataParallel(model)
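Conceptually, DDP's gradient synchronization all-reduces each parameter's gradient across ranks (in real code, `torch.distributed.all_reduce`) and divides by the world size, so every replica ends up with the same averaged gradient. The sketch below simulates that in plain Python, with a list of per-rank gradient vectors standing in for the GPUs. `ddp_average` is a hypothetical helper for illustration, not a PyTorch API.

```python
from typing import List


def ddp_average(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Simulate DDP's gradient all-reduce: every rank receives the element-wise mean."""
    world_size = len(per_rank_grads)
    # Element-wise sum across ranks (the all-reduce), then divide by the world size.
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    averaged = [s / world_size for s in summed]
    # After the all-reduce, every rank holds an identical copy of the result.
    return [list(averaged) for _ in range(world_size)]


# Two "GPUs", each holding a full replica of the same 3-element parameter.
grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
print(ddp_average(grads))  # [[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]]
```

This is only the right operation when every rank holds a full copy of the parameter, which is the data-parallel assumption DDP is built on.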

However, in my case it is the self_defined operator, not the input data, that is split across GPUs. The gradients computed on each GPU for this operator should therefore be gathered, not summed. Only this operator works this way; all other operators work the way DDP is designed.
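When a parameter is sharded rather than replicated, each rank's gradient covers a different slice, so the full gradient is the concatenation of the per-rank pieces (an all-gather, in real code `torch.distributed.all_gather`). Averaging them instead mixes unrelated slices. A plain-Python sketch of the difference, assuming the operator's weight is split evenly along its first dimension; `gather_shards` and `average` are hypothetical helpers.

```python
from typing import List


def gather_shards(per_rank_grads: List[List[float]]) -> List[float]:
    """Simulate an all-gather: concatenate each rank's shard in rank order."""
    full: List[float] = []
    for shard in per_rank_grads:
        full.extend(shard)
    return full


def average(per_rank_grads: List[List[float]]) -> List[float]:
    """What DDP would do instead: element-wise mean across ranks."""
    n = len(per_rank_grads)
    return [sum(vals) / n for vals in zip(*per_rank_grads)]


# A 4-element weight split across 2 GPUs: rank 0 owns [0:2], rank 1 owns [2:4].
shard_grads = [[1.0, 2.0], [3.0, 4.0]]
print(gather_shards(shard_grads))  # [1.0, 2.0, 3.0, 4.0]
print(average(shard_grads))        # [2.0, 3.0], mixing element 0 with element 2
```

In this picture, DDP's averaging is the wrong operation for the sharded operator, which is why it would need to be kept out of DDP's gradient synchronization.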

How can I make nn.parallel.DistributedDataParallel bypass this operator?

DDP is designed for data parallelism. AFAIK, your example is closer to model parallelism, which I don't think can be properly wrapped with DDP. Correct me if I'm wrong.