Custom filters in conv2d backward

Hi, I want to modify the convolutional filters only in the backward pass, replacing them with fixed random ones (Feedback Alignment: https://www.nature.com/articles/ncomms13276, https://arxiv.org/abs/1609.01596), so I wrote a custom module:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair


class Aconv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Aconv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = False  # only regular (non-transposed) conv is supported
        self.output_padding = _pair(0)
        self.groups = groups

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        # The fixed random weights used in the backward pass, and the stash
        # for the trained weights, are registered as buffers so they follow
        # the module on .cuda()/.to() but are never touched by the optimizer.
        self.register_buffer('backward_weight', torch.Tensor(self.weight.size()))
        self.register_buffer('forward_weight', torch.Tensor(self.weight.size()))
        self.reset_parameters()
        # Note: this must be a copy; `self.forward_weight = self.weight.data`
        # would alias the same storage and get clobbered in switch_mode.
        self.forward_weight.copy_(self.weight.data)

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        # Same init scheme for the fixed feedback weights, but they stay frozen.
        self.backward_weight.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        # Plain conv2d with whichever weights are currently in self.weight.
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def switch_mode(self, mode):
        if mode == 'backward':
            # Stash the trained weights and swap in the fixed random ones.
            self.forward_weight.copy_(self.weight.data)
            self.weight.data.copy_(self.backward_weight)
        elif mode == 'forward':
            # Restore the trained weights.
            self.weight.data.copy_(self.forward_weight)

Then I iterate over the layers of my model and call switch_mode on each of these layers before the forward pass and again before the backward pass, so that the right set of weights is in place each time.
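
For reference, the loop looks roughly like this (a simplified sketch; model, criterion, optimizer, and loader stand in for my actual setup):

def set_mode(model, mode):
    # Flip every Aconv2d layer in the model to the given weight set.
    for module in model.modules():
        if isinstance(module, Aconv2d):
            module.switch_mode(mode)

for inputs, targets in loader:
    set_mode(model, 'forward')    # trained weights for the forward pass
    loss = criterion(model(inputs), targets)
    set_mode(model, 'backward')   # fixed random weights for the backward pass
    optimizer.zero_grad()
    loss.backward()
    set_mode(model, 'forward')    # restore trained weights before the update
    optimizer.step()

Swapping self.weight.data in place after the forward also changes the values autograd saved for the backward, which is what makes the backward actually use the random weights.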

With this, one iteration takes ~45 seconds on my GPU, while with plain backprop it takes ~15 seconds.
I also tried writing a custom autograd function using torch.nn.grad.conv2d_input and torch.nn.grad.conv2d_weight, but it's even slower (~240 seconds).
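
That version was roughly along these lines (a sketch of what I tried; AconvFunction and the exact argument handling are simplified):

import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.grad import conv2d_input, conv2d_weight

class AconvFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, backward_weight, bias,
                stride, padding, dilation, groups):
        ctx.save_for_backward(input, backward_weight)
        ctx.conv_args = (stride, padding, dilation, groups)
        ctx.has_bias = bias is not None
        return F.conv2d(input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def backward(ctx, grad_output):
        input, backward_weight = ctx.saved_tensors
        stride, padding, dilation, groups = ctx.conv_args
        # The input gradient is computed with the fixed random weights
        # (the feedback-alignment part) ...
        grad_input = conv2d_input(input.size(), backward_weight, grad_output,
                                  stride, padding, dilation, groups)
        # ... while the weight gradient is the usual one.
        grad_weight = conv2d_weight(input, backward_weight.size(), grad_output,
                                    stride, padding, dilation, groups)
        grad_bias = grad_output.sum(dim=(0, 2, 3)) if ctx.has_bias else None
        return (grad_input, grad_weight, None, grad_bias,
                None, None, None, None)

with the module's forward returning AconvFunction.apply(input, self.weight, self.backward_weight, self.bias, self.stride, self.padding, self.dilation, self.groups). This avoids the weight swapping entirely, but as said it is much slower.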

I was wondering if there is a way to do this more efficiently. I was thinking about backward hooks, but my understanding is that they are executed after the backward of the module has already run, while here I'd need the swap to happen before the module's backward.
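
For example, with register_backward_hook the hook only fires once the gradients for the module have already been computed, so a swap there comes too late (sketch; layer is a placeholder for one of my Aconv2d layers):

def swap_hook(module, grad_input, grad_output):
    # At this point grad_input was already computed with the forward
    # weights, so swapping module.weight here does not affect the
    # current backward pass.
    module.switch_mode('backward')

layer.register_backward_hook(swap_hook)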

Any ideas/suggestions?