@Lucca, did you have any success with your implementation? I am trying to do exactly the same thing to experiment with sparsifying the gradient. I would be very happy if you could share your solution, if you found one.
I have the forwards part working thanks to Custom convolution layer - PyTorch Forums
but I am not sure what to put in the backward method
class SparseConv2d(torch.autograd.Function):
    """2-D convolution with an explicit forward/backward pair.

    Writing both passes by hand makes ``backward`` the natural hook for
    sparsifying ``grad_output`` before the gradients are formed.

    ``weight`` is expected with the kernel flattened per channel pair,
    i.e. shape ``(out_channels, in_channels, kernel_size * kernel_size)``
    (a 4-D ``(O, C, k, k)`` tensor also works — it is reshaped internally).
    """

    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, kernel_size, out_channels,
                dilation, padding, stride, info, bias=None):
        """Compute ``conv2d(input, weight) + bias``.

        input   : (N, C_in, H, W) tensor.
        weight  : (out_channels, C_in, kernel_size**2) tensor.
        dilation, padding, stride : pairs of ints.
        info    : opaque, non-differentiable extra data; stored on ctx.
        bias    : optional (out_channels,) tensor.
        Returns : (N, out_channels, H_out, W_out) tensor.
        """
        # Differentiable tensors must go through save_for_backward ...
        ctx.save_for_backward(input, weight, bias)
        # ... non-tensor configuration is stashed directly on ctx,
        # because backward needs it to rebuild the conv geometry.
        ctx.conf = (kernel_size, dilation, padding, stride)
        ctx.info = info

        batch = input.shape[0]
        # Standard conv output-size formula.
        out_h = ((input.shape[2] + 2 * padding[0]
                  - dilation[0] * (kernel_size - 1) - 1) // stride[0]) + 1
        out_w = ((input.shape[3] + 2 * padding[1]
                  - dilation[1] * (kernel_size - 1) - 1) // stride[1]) + 1

        # im2col: (N, C_in * k * k, L) with L = out_h * out_w.
        cols = torch.nn.functional.unfold(
            input, kernel_size=(kernel_size, kernel_size),
            padding=padding, dilation=dilation, stride=stride)
        # One batched matmul replaces the per-channel Python loops and
        # keeps batch-major ordering consistent with the final view:
        # (O, C_in*k*k) @ (N, C_in*k*k, L) -> (N, O, L).
        output = weight.reshape(out_channels, -1).matmul(cols)
        output = output.view(batch, out_channels, out_h, out_w)
        if bias is not None:
            # Broadcast (O,) over (N, O, H_out, W_out).
            output = output + bias.view(1, -1, 1, 1)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gradients w.r.t. input, weight and (optionally) bias.

        Sparsify ``grad_output`` here before the two conv2d_* calls to
        experiment with sparse gradients.
        """
        input, weight, bias = ctx.saved_tensors
        kernel_size, dilation, padding, stride = ctx.conf
        grad_input = grad_weight = grad_bias = None

        # conv2d_input / conv2d_weight expect a 4-D (O, C, k, k) kernel.
        w4 = weight.reshape(weight.shape[0], weight.shape[1],
                            kernel_size, kernel_size)
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(
                input.shape, w4, grad_output,
                stride=stride, padding=padding, dilation=dilation)
        if ctx.needs_input_grad[1]:
            # Reshape back so grad_weight matches weight's stored shape.
            grad_weight = torch.nn.grad.conv2d_weight(
                input, w4.shape, grad_output,
                stride=stride, padding=padding,
                dilation=dilation).reshape(weight.shape)
        if bias is not None and ctx.needs_input_grad[8]:
            # d(output)/d(bias) sums over batch and spatial dims.
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        # One gradient slot per forward argument (ctx excluded); the
        # non-tensor arguments (kernel_size, out_channels, dilation,
        # padding, stride, info) get None.
        return (grad_input, grad_weight, None, None, None, None,
                None, None, grad_bias)