A typical way to implement a self-defined conv2d can be:
import torch
import torch.nn as nn
from torch.autograd import Function

class Conv2DFunctionCustom(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        ctx.save_for_backward(input, weight, bias)
        ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups
        output = torch.nn.functional.conv2d(input, weight, bias=bias, stride=stride,
                                            padding=padding, dilation=dilation, groups=groups)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output,
                                                    stride, padding, dilation, groups)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output,
                                                      stride, padding, dilation, groups)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))
        # Non-tensor arguments (stride, padding, dilation, groups) receive no gradient.
        return grad_input, grad_weight, grad_bias, None, None, None, None

class Conv2DCustom(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=False):
        super(Conv2DCustom, self).__init__(in_channels, out_channels, kernel_size,
                                           stride=stride, padding=padding,
                                           dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        return Conv2DFunctionCustom.apply(x, self.weight, self.bias, self.stride,
                                          self.padding, self.dilation, self.groups)
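For a small input this works as expected; the following quick sanity check (shapes and layer sizes are only illustrative) exercises both the forward and the custom backward:

layer = Conv2DCustom(3, 8, kernel_size=3, padding=1, bias=True)
x = torch.randn(2, 3, 16, 16, requires_grad=True)
layer(x).sum().backward()
print(x.grad.shape, layer.weight.grad.shape, layer.bias.grad.shape)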
However, this approach gets stuck when the tensor size is very large and raises an illegal CUDA memory error in those cases. The error is something like:
RuntimeError: THCudaTensor sizes too large for THCDeviceTensor conversion at /data/yangwei/mount_52/source/pytorch/pytorch_v1.8.1/aten/src/THC/THCDeviceTensorUtils.cuh:71
and the error is caused by the torch.nn.grad.conv2d_weight function.
So how can the code above be modified to handle large tensors efficiently (a self-defined backward is still needed)? One possible solution is to split the tensor, call torch.nn.grad.conv2d_weight several times, and then combine the split gradients, but this is still very slow.
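For reference, a minimal sketch of that splitting idea: since the weight gradient is a sum over the batch dimension, the input and grad_output can be split into batch chunks and the per-chunk gradients accumulated. The function name and chunk_size below are illustrative, not a tested fix for the underlying error:

def conv2d_weight_chunked(input, weight_shape, grad_output, stride, padding,
                          dilation, groups, chunk_size=8):
    # Accumulate the weight gradient over batch chunks so that each
    # conv2d_weight call sees a smaller tensor.
    grad_weight = None
    for inp_chunk, go_chunk in zip(input.split(chunk_size, dim=0),
                                   grad_output.split(chunk_size, dim=0)):
        gw = torch.nn.grad.conv2d_weight(inp_chunk, weight_shape, go_chunk,
                                         stride, padding, dilation, groups)
        grad_weight = gw if grad_weight is None else grad_weight + gw
    return grad_weight

This could replace the torch.nn.grad.conv2d_weight call in backward, but it trades memory for many smaller kernel launches, which is why it remains slow.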