Overwrite forward in layers to work with Data Parallel

Hello,
I want to create a function that overwrites the forward and backward pass of an nn.Module. For example, I load a ResNet (or any other network) and automatically replace the forward and backward pass of all its layers with a custom autograd function.

The code I have now works on CPU and on a single GPU, but it breaks when I extend it to DataParallel: the data and the model end up on different GPUs (exactly as described here: Issue for DataParallel · Issue #8637 · pytorch/pytorch · GitHub and DataParallel on modules with dynamically overwritten forwards).

I am also not sure whether this is the best way of overwriting the forward/backward pass.

I define a function to be applied to every model layer:

import torch
import torch.nn as nn
from torch import autograd


def override_backward(layer):
    # Replace the forward of every (transposed) convolution with a closure that
    # calls the custom autograd function, so its backward is used instead of the default one.
    if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
        def forward_conv(x):
            if layer.bias is None:
                return Conv2dFA.apply(x,
                                      layer.weight,
                                      layer.weight_fa,
                                      None,
                                      None,
                                      layer.stride,
                                      layer.padding,
                                      layer.dilation,
                                      layer.groups)
            else:
                return Conv2dFA.apply(x,
                                      layer.weight,
                                      layer.weight_fa,
                                      layer.bias,
                                      layer.bias_fa,
                                      layer.stride,
                                      layer.padding,
                                      layer.dilation,
                                      layer.groups)
        layer.forward = forward_conv  # instance-level override: the closure captures this specific layer object
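
For completeness, override_backward assumes that every conv layer already has weight_fa / bias_fa tensors attached. A rough sketch of how they could be attached (registered as buffers so that .cuda() and DataParallel's replication carry them along with the rest of the module; the helper name add_feedback_weights is made up for illustration):

def add_feedback_weights(layer):
    # Illustrative helper: give each conv layer fixed random feedback weights.
    # Registering them as buffers means .to(device), .cuda() and DataParallel's
    # replication treat them like any other module state.
    if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
        layer.register_buffer('weight_fa', torch.randn_like(layer.weight))
        layer.register_buffer('bias_fa', torch.randn_like(layer.bias) if layer.bias is not None else None)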

The function Conv2dFA is:

class Conv2dFA(autograd.Function):
    @staticmethod
    def forward(context, input, kernels, kernels_fa, bias, bias_fa, stride, padding, dilation, groups):
        context.stride, context.padding, context.dilation, context.groups = stride, padding, dilation, groups
        context.save_for_backward(input, kernels, kernels_fa, bias, bias_fa)
        # The forward pass is an ordinary convolution with the true weights.
        output = torch.nn.functional.conv2d(input,
                                            kernels,
                                            bias=bias,
                                            stride=stride,
                                            padding=padding,
                                            dilation=dilation,
                                            groups=groups)
        return output

    @staticmethod
    def backward(context, grad_output):
        input, kernels, kernels_fa, bias, bias_fa = context.saved_tensors
        grad_input = grad_kernels = grad_kernels_fa = grad_bias = grad_bias_fa = None

        # The gradient w.r.t. the input is computed with the feedback weights (kernels_fa)
        # instead of the forward weights.
        if context.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input_size=input.shape,
                                                    weight=kernels_fa,
                                                    grad_output=grad_output,
                                                    stride=context.stride,
                                                    padding=context.padding,
                                                    dilation=context.dilation,
                                                    groups=context.groups)

        # The gradient w.r.t. the forward weights uses the standard formula.
        if context.needs_input_grad[1]:
            grad_kernels = torch.nn.grad.conv2d_weight(input=input,
                                                       weight_size=kernels_fa.shape,
                                                       grad_output=grad_output,
                                                       stride=context.stride,
                                                       padding=context.padding,
                                                       dilation=context.dilation,
                                                       groups=context.groups)

        if bias is not None and context.needs_input_grad[3]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))  # sum over batch and spatial dimensions

        # stride, padding, dilation and groups are non-tensor inputs, so they get no gradient.
        return grad_input, grad_kernels, grad_kernels_fa, grad_bias, grad_bias_fa, None, None, None, None

And I apply it like this:

model_fa = resnet.resnet18()
model_fa.apply(override_backward)
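
For reference, this is roughly how I exercise it (batch size and input shape are just placeholders): the single-device path runs fine, and wrapping the patched model in nn.DataParallel is where it breaks.

# Single device: this works.
model_fa = model_fa.cuda()
x = torch.randn(8, 3, 224, 224).cuda()  # placeholder input
model_fa(x).sum().backward()

# Multiple GPUs: this is where it fails, presumably because each replica still
# calls the forward closure that was bound to the original layers on cuda:0.
model_dp = nn.DataParallel(model_fa)
model_dp(x).sum().backward()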

Is there a way of doing this dynamic forward overwrite that works with DataParallel? I don't want to create custom classes, since I want the method to be applicable to any neural net.

Thanks in advance.

This post shows how to overwrite the forward method of layers: How can I replace the forward method of a predefined torchvision model with my customized forward function? - #6 by Philipp_Friebertshau

Hope it helps.
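
Not sure if that is exactly what you need, but one variant of the same idea that avoids storing a closure on each instance is to patch forward at the class level, so the replacement takes self and looks up the layer's own (possibly replicated) tensors at call time. A rough sketch, reusing your Conv2dFA from above (note it affects every nn.Conv2d in the process):

def conv2d_forward_fa(self, x):
    # `self` is whichever Conv2d instance is being called, so DataParallel replicas
    # use their own copies of weight / weight_fa instead of the original layer's.
    return Conv2dFA.apply(x,
                          self.weight,
                          self.weight_fa,
                          self.bias,
                          getattr(self, 'bias_fa', None),
                          self.stride,
                          self.padding,
                          self.dilation,
                          self.groups)

nn.Conv2d.forward = conv2d_forward_fa  # class-level patch: applies to all Conv2d layers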