Hi, I want to modify the convolutional filters only in the backward pass to use some fixed random ones (Feedback Alignment https://www.nature.com/articles/ncomms13276, https://arxiv.org/abs/1609.01596), so I wrote a custom module:
class Aconv2d(nn.Module):
    """2-D convolution whose kernel can be swapped for a fixed random one.

    Written for feedback-alignment style training: the module keeps a
    learnable ``weight`` plus a fixed random ``backward_weight`` of the same
    shape.  ``switch_mode('backward')`` stashes the learned kernel and loads
    the random one into ``weight``; ``switch_mode('forward')`` restores the
    learned kernel.

    Typical loop: forward pass -> ``switch_mode('backward')`` ->
    ``loss.backward()`` -> ``switch_mode('forward')`` -> ``optimizer.step()``.
    Stepping the optimizer while in backward mode would update the random
    kernel copy, and that update is discarded on the next restore.

    Args:
        in_channels: Number of input channels; must be divisible by ``groups``.
        out_channels: Number of output channels; must be divisible by
            ``groups``.
        kernel_size: Int or pair giving the kernel height/width.
        stride: Int or pair, convolution stride.
        padding: Int or pair, zero padding added to both sides.
        dilation: Int or pair, spacing between kernel elements.
        groups: Number of blocked connections from input to output channels.
        bias: If True, adds a learnable bias.

    Raises:
        ValueError: If ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Aconv2d, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        # This module is never transposed; the attributes are kept so code
        # that inspects them keeps working.
        self.transposed = False
        self.output_padding = _pair(0)
        self.groups = groups

        weight_shape = (out_channels, in_channels // groups) + tuple(kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        # Registered as buffers so that .cuda()/.to() move them together with
        # ``weight``; otherwise every switch_mode() call would copy between
        # host and device memory.
        self.register_buffer('backward_weight', torch.Tensor(*weight_shape))
        # Stash for the learned kernel. It must own its storage: aliasing
        # ``weight.data`` would let switch_mode('backward') overwrite the stash.
        self.register_buffer('forward_weight', torch.Tensor(*weight_shape))

        # Which kernel currently lives in ``weight``: 'forward' or 'backward'.
        self._mode = 'forward'
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight, backward_weight and bias from U(-stdv, stdv).

        ``stdv`` is ``1 / sqrt(in_channels * prod(kernel_size))``.  The
        module is left in forward mode with the stash synced to the new
        learned kernel.
        """
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        self.backward_weight.uniform_(-stdv, stdv)
        self.forward_weight.copy_(self.weight.data)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
        self._mode = 'forward'

    def forward(self, input):
        """Apply the convolution using whatever kernel ``weight`` holds now."""
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def switch_mode(self, mode):
        """Load the learned ('forward') or random ('backward') kernel.

        Switching to the mode that is already active is a no-op, so calling
        this defensively at the start of every iteration neither clobbers
        the stash nor undoes an optimizer update.

        Args:
            mode: Either ``'forward'`` or ``'backward'``.

        Raises:
            ValueError: If ``mode`` is not one of the two supported values.
        """
        if mode not in ('forward', 'backward'):
            raise ValueError(
                "mode must be 'forward' or 'backward', got {!r}".format(mode))
        if mode == self._mode:
            return
        # ``.data`` is used on purpose: the swap must not be recorded by
        # autograd, and the stash must not become part of the graph.
        if mode == 'backward':
            self.forward_weight.copy_(self.weight.data)
            self.weight.data.copy_(self.backward_weight)
        else:
            self.weight.data.copy_(self.forward_weight)
        self._mode = mode
Then I iterate over my model's layers and call `switch_mode` on each of them before the forward pass and again before the backward pass, so that the right weights are used in each.
With this an iteration takes ~45 seconds on my GPU, while with plain backprop it takes ~15 seconds.
I also tried writing a custom autograd function using `conv2d_input` and `conv2d_weight` (from `torch.nn.grad`), but it's even slower (~240 seconds).
I was wondering if there is a way to do this more efficiently. I was thinking about backward hooks, but my understanding is that they’re executed after the backward of the module, while here I’d need them to be performed before the module backward.
Any ideas/suggestions?