Conv2d backward pass using unfold+mm

Hi,
I am trying to implement the backward pass for Conv2d using unfold+mm. My first implementation used torch.nn.grad.conv2d_input, which works correctly; internally, that function calls torch.conv_transpose2d. To reproduce that step, I am using the code from this post to implement conv_transpose2d via unfold: ConvTranspose2d using unfold - #4 by santacml.
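
For reference, this is the sanity check I used to convince myself of that relationship. The shapes match my setup below (5x5 kernel, stride 1, padding 0, so the output_padding that conv2d_input computes internally ends up being 0):

import torch
import torch.nn.functional as F

weight = torch.randn(20, 10, 5, 5)
grad_output = torch.randn(100, 20, 8, 8)
inp_shape = (100, 10, 12, 12)

ref = torch.nn.grad.conv2d_input(inp_shape, weight, grad_output, stride=1, padding=0)
via_transpose = F.conv_transpose2d(grad_output, weight, stride=1, padding=0)
print(torch.allclose(ref, via_transpose))  # should print True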

However, I am running into an issue and I’m not sure how to proceed. I have provided my code below, including the forward pass for comparison.

def forward(ctx, inp, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1):
	ctx.save_for_backward(inp, weight, bias)
	# Removed code showing stride, padding, dilation and groups being saved to ctx.

	(b, n_C_prev, n_H_prev, n_W_prev) = inp.shape
	(n_oC, n_iC, f, f) = weight.shape  # assumes a square f x f kernel

	n_H = ((n_H_prev - f + (2 * padding[0])) // stride[0]) + 1
	n_W = ((n_W_prev - f + (2 * padding[1])) // stride[1]) + 1

	# unfold applies the zero-padding itself, so the input is not padded separately.
	inp_unf = torch.nn.functional.unfold(inp, (f, f), padding=padding, stride=stride)  # [b, n_iC*f*f, n_H*n_W]

	inp_unf = inp_unf.transpose(1, 2)
	inp_unf = inp_unf.reshape(-1, inp_unf.shape[-1])
	weight_t = weight.view(weight.size(0), -1).t()

	out_unf = torch.matmul(inp_unf, weight_t)
	out_unf = out_unf.view(inp.shape[0], -1, out_unf.shape[-1])
	out_unf = out_unf.transpose(1, 2)
	output = out_unf.view(b, n_oC, n_H, n_W)
	if bias is not None:
		output = output + bias[:, None, None]
	return output
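
For completeness, this is how I check the forward pass against the built-in convolution (Conv2dUnfold is just a placeholder name here for the autograd.Function that wraps the code above):

import torch
import torch.nn.functional as F

inp = torch.randn(100, 10, 12, 12)
weight = torch.randn(20, 10, 5, 5)
bias = torch.randn(20)

out = Conv2dUnfold.apply(inp, weight, bias)  # placeholder name for the Function above
ref = F.conv2d(inp, weight, bias)            # built-in conv, stride 1, padding 0
print(out.shape, torch.allclose(out, ref, atol=1e-5))  # should print torch.Size([100, 20, 8, 8]) True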

def backward(ctx, grad_output):  # grad_output.shape = [100,20,8,8]
	inp, weight, bias = ctx.saved_tensors # weight.shape = [20, 10, 5, 5]
	# Removed code which restores stride, padding, dilation & groups from forward.
	grad_input = None	

	if ctx.needs_input_grad[0]:
		# This code works correctly and trains the CNN. I am using it as a comparison for my code.
		# grad_input_check.shape = [100, 10, 12, 12]
		grad_input_check = torch.nn.grad.conv2d_input(inp.shape, weight, grad_output, stride, padding, dilation, groups)

		# Does the same thing as `_grad_input_padding` in `torch/nn/grad.py`; ks is the kernel size.
		ks = (weight.shape[2], weight.shape[3])  # (5, 5)
		gi_padding = grad_input_padding(grad_output, inp.shape, stride, padding, ks, dilation)

		go_unf = torch.nn.functional.unfold(grad_output, ks, padding=gi_padding) # [100, 500, 16]
		go_unf = go_unf.transpose(1, 2)	# [100, 16, 500]

		w_rot = torch.rot90(weight, 2, [2, 3]) # [20, 10, 5, 5]
		w_rot = w_rot.view(w_rot.size(0), -1).t()	# [250, 20]

		grad_input = go_unf.mm(w_rot)

Of course, I cannot multiply go_unf and w_rot with those shapes. I think the problem stems from the number of channels in grad_output. Working backwards, it seems the output of unfold would need a shape of [100, 250, 72] for the dimensions to come out right, but I don't see how to get there.
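
One idea I have been playing with, but am not sure about, is to unfold grad_output with padding f - 1 (a "full" correlation, which only covers my stride 1 / padding 0 case) and to swap the in/out channel dimensions of the rotated weight before flattening it, so that it lines up with unfold's [n_oC, f, f] column layout:

f = weight.shape[2]  # 5, square kernel

# Pad grad_output by f - 1 so every input position sees the full kernel support.
# (Stride > 1 would also need zeros inserted between grad_output elements, which this ignores.)
go_unf = torch.nn.functional.unfold(grad_output, (f, f), padding=f - 1)  # [100, 500, 144]
go_unf = go_unf.transpose(1, 2)                                          # [100, 144, 500]

# Rotate the kernel 180 degrees, swap in/out channels, and flatten to [n_oC*f*f, n_iC].
w_rot = torch.rot90(weight, 2, [2, 3]).permute(1, 0, 2, 3)  # [10, 20, 5, 5]
w_mat = w_rot.reshape(w_rot.size(0), -1).t()                # [500, 10]

grad_input = go_unf.matmul(w_mat)                           # [100, 144, 10]
grad_input = grad_input.transpose(1, 2).reshape(inp.shape)  # [100, 10, 12, 12]

Does that look like the right direction? Any help would be much appreciated, thank you!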