Hi,
I was hoping that somebody could write out the manual backward pass for a conv2d layer.
So far I got everything working with the following code:
import torch
import torch.nn.functional as F
from torch.autograd.function import Function
from torch.autograd import gradcheck

torch.set_printoptions(threshold=10000)


class GradBatch_Conv2dFunction(Function):

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, module=None):
        ctx.save_for_backward(input, weight, bias)
        ctx.module = module
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        return F.conv2d(input=input, weight=weight, bias=bias, stride=stride,
                        padding=padding, dilation=dilation, groups=groups)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        dilation = ctx.dilation
        groups = ctx.groups
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output,
                                                    stride, padding, dilation, groups)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output,
                                                      stride, padding, dilation, groups)
        if bias is not None and ctx.needs_input_grad[2]:
            # sum over batch and spatial dims; result already has shape [C_out]
            grad_bias = grad_output.sum((0, 2, 3))
        # one gradient per forward argument; the non-tensor args get None
        return grad_input, grad_weight, grad_bias, None, None, None, None, None


conv2d = GradBatch_Conv2dFunction.apply

BS, C_in, C_out = 1, 2, 1
kernel_size, groups = 3, 2
input = (torch.randn(BS, groups*C_in, 4, 4, dtype=torch.double, requires_grad=True),  # input
         torch.randn(groups*C_out, C_in, kernel_size, kernel_size, dtype=torch.double, requires_grad=True),  # weight
         # torch.randn(groups*C_out, dtype=torch.double, requires_grad=True),  # bias
         None,    # bias
         1,       # stride
         0,       # padding
         1,       # dilation
         groups,  # groups
         None,    # module
         )

print(f"{torch.autograd.gradcheck(GradBatch_Conv2dFunction.apply, input, eps=1e-6, atol=1e-4)=}")
Yet, the moment I change groups to a value > 1, gradcheck fails with a Jacobian mismatch, and when I inspect the difference between the analytical and the numerical gradient, there is a most peculiar structure in it. The 36 rows of the diff match the 36 elements of the [2, 2, 3, 3] weight, and the 8 columns match the 8 elements of the [1, 2, 2, 2] output:
diff.shape=torch.Size([36, 8])
[tensor([[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, True, True, True, True],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False]])]
The boolean matrix above was obtained by modifying the helper checkIfNumericalAnalyticAreClose(a, n, j, error_str='') inside torch/autograd/gradcheck.py to:
def checkIfNumericalAnalyticAreClose(a, n, j, error_str=''):
    if not torch.allclose(a, n, rtol, atol):
        diff = (n - a).abs()
        print(f"{diff.shape=}")
        print([diff > atol])
        return fail_test(error_str + 'Jacobian mismatch for output %d with respect to input %d,\n'
                         'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
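To narrow things down, I also tried checking torch.nn.grad.conv2d_weight in isolation, against the weight gradient autograd itself computes for the same grouped convolution. A minimal standalone check with the same shapes as above, independent of my Function:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
groups, C_in, C_out, k = 2, 2, 1, 3
x = torch.randn(1, groups * C_in, 4, 4, dtype=torch.double)
w = torch.randn(groups * C_out, C_in, k, k, dtype=torch.double, requires_grad=True)

out = F.conv2d(x, w, groups=groups)
g = torch.randn_like(out)

# reference: weight gradient computed by autograd itself
gw_ref, = torch.autograd.grad(out, w, g)
# candidate: the helper my backward relies on
gw_man = torch.nn.grad.conv2d_weight(x, w.shape, g, stride=1, padding=0,
                                     dilation=1, groups=groups)
print(torch.allclose(gw_ref, gw_man))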
Am I computing the backward pass wrong for multiple groups?
Any help or a pointer in the right direction is much appreciated!
Thanks!
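EDIT: In case it helps with debugging, here is a sketch of computing the weight gradient group by group instead of in one grouped call. The helper name conv2d_weight_per_group is mine, not a PyTorch API; it relies only on the fact that a grouped convolution is equivalent to `groups` independent convolutions over channel slices, so the per-group weight gradients can be concatenated along the output-channel dimension:

def conv2d_weight_per_group(input, weight_shape, grad_output,
                            stride=1, padding=0, dilation=1, groups=1):
    # channel width of each group's slice in input and grad_output
    in_per_g = input.shape[1] // groups
    out_per_g = grad_output.shape[1] // groups
    # weight shape of a single group: [C_out // groups, C_in_per_group, kH, kW]
    group_shape = (weight_shape[0] // groups,) + tuple(weight_shape[1:])
    grads = []
    for gi in range(groups):
        x_g = input[:, gi * in_per_g:(gi + 1) * in_per_g]
        go_g = grad_output[:, gi * out_per_g:(gi + 1) * out_per_g]
        grads.append(torch.nn.grad.conv2d_weight(
            x_g, group_shape, go_g, stride, padding, dilation, groups=1))
    # group gi's filters occupy output channels [gi*out_per_g, (gi+1)*out_per_g)
    return torch.cat(grads, dim=0)

If that per-group decomposition is right, swapping this in for the torch.nn.grad.conv2d_weight call in my backward should show whether the grouped code path is what gradcheck is tripping over.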