How to write a self-defined conv2d in an efficient way?

A typical way to implement a self-defined conv2d is:

import torch
import torch.nn as nn
from torch.autograd import Function


class Conv2DFunctionCustom(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        # Save the tensors and hyperparameters needed to recompute gradients in backward.
        ctx.save_for_backward(input, weight, bias)
        ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups
        output = torch.nn.functional.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups

        grad_input = grad_weight = grad_bias = None

        # needs_input_grad follows the order of forward's arguments:
        # 0 -> input, 1 -> weight, 2 -> bias.
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride, padding, dilation, groups)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, stride, padding, dilation, groups)
        if bias is not None and ctx.needs_input_grad[2]:
            # The bias gradient is grad_output summed over the batch and spatial dimensions.
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        # The non-tensor arguments (stride, padding, dilation, groups) get None gradients.
        return grad_input, grad_weight, grad_bias, None, None, None, None


class Conv2DCustom(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(Conv2DCustom, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        return Conv2DFunctionCustom.apply(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
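
A quick sanity check (a minimal sketch; the layer sizes are arbitrary) to confirm the custom forward/backward runs and produces gradients of the expected shapes:

x = torch.randn(2, 3, 8, 8, requires_grad=True)
conv = Conv2DCustom(3, 4, kernel_size=3, padding=1, bias=False)
out = conv(x)
out.sum().backward()
print(x.grad.shape, conv.weight.grad.shape)  # torch.Size([2, 3, 8, 8]) torch.Size([4, 3, 3, 3])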

However, this approach gets stuck when the tensor size is very large and raises an illegal CUDA memory error in those cases. The error is something like:

RuntimeError: THCudaTensor sizes too large for THCDeviceTensor conversion at /data/yangwei/mount_52/source/pytorch/pytorch_v1.8.1/aten/src/THC/THCDeviceTensorUtils.cuh:71

and the error is caused by the torch.nn.grad.conv2d_weight function.

So how can I modify the code above to handle large tensors efficiently (while still keeping a self-defined backward)? One possible solution might be to split the tensor, call torch.nn.grad.conv2d_weight several times, and then combine the split gradients, but this is still very slow.
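
For reference, a minimal sketch of that splitting idea, assuming the split is done along the batch dimension: the weight gradient of a convolution is a sum over the batch, so per-chunk results can simply be accumulated. The helper name and the number of chunks are made up for illustration.

import torch

def conv2d_weight_chunked(input, weight_shape, grad_output, stride, padding, dilation, groups, chunks=4):
    # grad_weight is a sum over the batch dimension, so computing it chunk by
    # chunk and accumulating gives the same result as one large call.
    grad_weight = None
    for inp_c, go_c in zip(input.chunk(chunks, dim=0), grad_output.chunk(chunks, dim=0)):
        gw = torch.nn.grad.conv2d_weight(inp_c, weight_shape, go_c, stride, padding, dilation, groups)
        grad_weight = gw if grad_weight is None else grad_weight + gw
    return grad_weight

In backward, grad_weight = conv2d_weight_chunked(input, weight.shape, grad_output, stride, padding, dilation, groups) would then replace the single torch.nn.grad.conv2d_weight call.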

torch.nn.grad isn’t really maintained anymore actually. What is your application for this?

I need to treat the calculation of grad_input and grad_weight differently for the conv2d layer, so I need to write a self-defined conv2d backward. For example, I need to implement something like:

grad_output_1 = grad_output * A + B
grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output_1, stride, padding, dilation, groups)

grad_output_2 = grad_output * C + D
grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output_2, stride, padding, dilation, groups)
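
A minimal sketch of how those two lines could slot into the backward of the custom Function above. A, B, C, D are placeholders from my question; here they are assumed to have been stashed on ctx during forward, which is just one possible way to pass them in.

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups
        A, B, C, D = ctx.A, ctx.B, ctx.C, ctx.D  # placeholder tensors, assumed saved in forward
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # One transformation of grad_output for the input gradient ...
            grad_output_1 = grad_output * A + B
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output_1, stride, padding, dilation, groups)
        if ctx.needs_input_grad[1]:
            # ... and a different one for the weight gradient.
            grad_output_2 = grad_output * C + D
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output_2, stride, padding, dilation, groups)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        return grad_input, grad_weight, grad_bias, None, None, None, None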

What version of torch are you using? Can you try again with the latest version?
If you still run into the issue, maybe reduce the batch size?

I am using PyTorch 1.12.1.
Will a higher version fix this problem?

The implementation changed between 1.12.1 and 1.13. It won’t necessarily solve your issue, but it might.