Manual backward pass of standard Conv2D

Hi,

I was hoping that somebody could write out the manual backward pass for a conv2d layer.

So far I got everything working with the following code:

import torch
import torch.nn.functional as F
from torch.autograd.function import Function
from torch.autograd import gradcheck
torch.set_printoptions(threshold=10000)

class GradBatch_Conv2dFunction(Function):

	@staticmethod
	def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, module=None):
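		# stash the tensors and hyper-parameters that backward will need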
		ctx.save_for_backward(input, weight, bias)
		ctx.module = module
		ctx.stride = stride

		ctx.padding = padding
		ctx.dilation = dilation
		ctx.groups = groups

		return F.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

	@staticmethod
	def backward(ctx, grad_output):

		input, weight, bias = ctx.saved_tensors
		stride = ctx.stride
		padding = ctx.padding
		dilation = ctx.dilation
		groups = ctx.groups
		grad_input = grad_weight = grad_bias = None

		if ctx.needs_input_grad[0]:
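			# gradient w.r.t. the input (effectively a transposed convolution of grad_output with the weights)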
			grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride, padding, dilation, groups)
		if ctx.needs_input_grad[1]:
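			# gradient w.r.t. the weights, via PyTorch's conv2d_weight helper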
			grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, stride, padding, dilation, groups)
		if bias is not None and ctx.needs_input_grad[2]:
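			# the bias is broadcast over the batch and spatial dimensions, so its gradient is grad_output summed over those dims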
			grad_bias = grad_output.sum((0, 2, 3))

		return grad_input, grad_weight, grad_bias, None, None, None, None, None


conv2d = GradBatch_Conv2dFunction.apply
BS, C_in, C_out = 1, 2, 1
kernel_size, groups = 3, 2

inputs = (torch.randn(BS, groups*C_in, 4, 4, dtype=torch.double, requires_grad=True), # input
	 torch.randn(groups*C_out, C_in, kernel_size, kernel_size, dtype=torch.double, requires_grad=True), # weight
	 # torch.randn(groups*C_out, dtype=torch.double, requires_grad=True), # bias
	 None, # bias
	 1, # stride
	 0, # padding
	 1, # dilation
	 groups, # groups
	 None # module
	 )

print(f"{torch.autograd.gradcheck(GradBatch_Conv2dFunction.apply, input, eps=1e-6, atol=1e-4)=}")

Yet the moment I change groups to a value > 1, gradcheck fails with a Jacobian mismatch, and when I inspect the difference between the analytical and the numerical Jacobian there is a most peculiar structure in it:

diff.shape=torch.Size([36, 8])
[tensor([[False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False]])]

The boolean matrix above was obtained by modifying gradcheck's internal helper checkIfNumericalAnalyticAreClose(a, n, j, error_str='') to:

def checkIfNumericalAnalyticAreClose(a, n, j, error_str=''):
    if not torch.allclose(a, n, rtol, atol):
        diff = (n - a).abs()
        print(f"{diff.shape=}")
        print([diff > atol])
        return fail_test(error_str + 'Jacobian mismatch for output %d with respect to input %d,\n'
                         'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
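
In case it helps with reproducing this, the weight gradient can also be checked without touching gradcheck at all, by comparing torch.nn.grad.conv2d_weight directly against the weight gradient that autograd computes for F.conv2d. This is only a minimal sketch along the lines of the setup above, not the exact script I ran:

import torch
import torch.nn.functional as F

BS, C_in, C_out, kernel_size, groups = 1, 2, 1, 3, 2

x = torch.randn(BS, groups * C_in, 4, 4, dtype=torch.double)
w = torch.randn(groups * C_out, C_in, kernel_size, kernel_size, dtype=torch.double, requires_grad=True)

out = F.conv2d(x, w, None, 1, 0, 1, groups)
grad_out = torch.randn_like(out)

# reference weight gradient from autograd
ref_grad_w, = torch.autograd.grad(out, w, grad_out)

# weight gradient from the helper used in my backward
manual_grad_w = torch.nn.grad.conv2d_weight(x, w.shape, grad_out, stride=1, padding=0, dilation=1, groups=groups)

print(torch.allclose(ref_grad_w, manual_grad_w))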

Am I computing the backward pass incorrectly for multiple groups?
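
For what it's worth, my understanding is that the grouped weight gradient should simply decompose per group: split input, weight and grad_output into their groups, compute each group's gradient with groups=1, and concatenate the results along the output-channel dimension. A rough, untested sketch of that idea, using a hypothetical helper grouped_conv2d_weight:

import torch

def grouped_conv2d_weight(input, weight_shape, grad_output, stride, padding, dilation, groups):
    # per group: input has in_channels // groups channels, the weight has out_channels // groups filters
    in_per_group = input.shape[1] // groups
    out_per_group = grad_output.shape[1] // groups
    per_group_shape = (weight_shape[0] // groups,) + tuple(weight_shape[1:])
    grads = []
    for g in range(groups):
        x_g = input[:, g * in_per_group:(g + 1) * in_per_group]
        go_g = grad_output[:, g * out_per_group:(g + 1) * out_per_group]
        grads.append(torch.nn.grad.conv2d_weight(x_g, per_group_shape, go_g, stride, padding, dilation, groups=1))
    return torch.cat(grads, dim=0)

If this is not equivalent to calling conv2d_weight with groups=groups directly, I would like to understand why.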

Any help or a pointer in the right direction is much appreciated!
Thanks!