Hi,
I am trying to write a custom backward function for ConvTranspose2d. I am using nn.grad as a reference, where conv2d_weight computes the gradient w.r.t. the weight. I thought that simply replacing torch.conv2d with nn.functional.conv_transpose2d would do the job, but it did not.
So my question is: how can I change the conv2d_weight function (copied below from nn.grad for convenience) so that it computes the gradient of ConvTranspose2d w.r.t. its weight?
Thanks,
Tahereh
import torch
from torch.nn.modules.utils import _pair  # helper used by nn.grad to expand ints to 2-tuples


def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the gradient of conv2d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
    """
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    in_channels = input.shape[1]
    out_channels = grad_output.shape[1]
    min_batch = input.shape[0]

    grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1,
                                                  1)
    grad_output = grad_output.contiguous().view(
        grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2],
        grad_output.shape[3])

    input = input.contiguous().view(1, input.shape[0] * input.shape[1],
                                    input.shape[2], input.shape[3])

    # Note the swapped roles: the forward stride is passed as dilation and
    # the forward dilation as stride.
    grad_weight = torch.conv2d(input, grad_output, None, dilation, padding,
                               stride, in_channels * min_batch)

    grad_weight = grad_weight.contiguous().view(
        min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2],
        grad_weight.shape[3])

    return grad_weight.sum(dim=0).view(
        in_channels // groups, out_channels,
        grad_weight.shape[2], grad_weight.shape[3]).transpose(0, 1).narrow(
            2, 0, weight_size[2]).narrow(3, 0, weight_size[3])
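
One possible approach, offered as a sketch rather than a confirmed answer: for a fixed weight, conv_transpose2d is the adjoint (w.r.t. the input) of conv2d with the same weight. So the weight gradient of the transposed convolution can be obtained from the existing conv2d_weight by swapping the roles of input and grad_output, without touching its body. The helper name conv_transpose2d_weight below is hypothetical (it is not part of nn.grad). The check against autograd only covers groups=1 and output_padding=0, so other settings should be treated as untested assumptions.

```python
import torch
import torch.nn.functional as F
from torch.nn.grad import conv2d_weight


def conv_transpose2d_weight(input, weight_size, grad_output,
                            stride=1, padding=0, dilation=1, groups=1):
    """Hypothetical helper: gradient of conv_transpose2d w.r.t. its weight.

    input:       (minibatch x in_channels x iH x iW), input of the transposed conv
    weight_size: (in_channels x out_channels // groups x kH x kW)
    grad_output: (minibatch x out_channels x oH x oW)
    """
    # The transposed conv's grad_output plays the role of conv2d's input,
    # and the transposed conv's input plays the role of conv2d's grad_output.
    return conv2d_weight(grad_output, weight_size, input,
                         stride=stride, padding=padding,
                         dilation=dilation, groups=groups)


# Compare against autograd on a small random example (double precision).
torch.manual_seed(0)
x = torch.randn(2, 4, 5, 5, dtype=torch.double)
w = torch.randn(4, 3, 3, 3, dtype=torch.double, requires_grad=True)

y = F.conv_transpose2d(x, w, stride=2, padding=1)
grad_output = torch.randn_like(y)

(expected,) = torch.autograd.grad(y, w, grad_output)
actual = conv_transpose2d_weight(x, w.shape, grad_output, stride=2, padding=1)

max_abs_diff = (actual - expected).abs().max().item()
print("max abs difference vs autograd:", max_abs_diff)
```

If the two gradients agree on your own shapes and hyperparameters, the helper can be called from the backward of a custom torch.autograd.Function in place of a modified conv2d_weight.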