Reducing GPU memory cost when building a custom convolution layer

Dear all,

I’m building a custom convolution layer, but it uses too much GPU memory.

The idea of my custom conv operation is to add an extra weight “s” to the normal convolution operation.

For example, I have an input of shape (16, 1, 28, 28) and a custom conv kernel of shape (32, 1, 3, 3), with stride = 1 and padding = 1.

Then for each input image (1, 1, 28, 28) I slide the kernel across the whole image, which gives 28 * 28 custom conv operations. Each one computes output = torch.sum(input * weight * s), where s is calculated from the input patch and the weight in some way.
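In other words, one step at a single spatial location looks roughly like this (the sigmoid for ‘s’ is a made-up placeholder, since the real formula for ‘s’ isn’t shown here):

```python
import torch

patch = torch.randn(1, 3, 3)        # one 3x3 window of the (1-channel) input
weight = torch.randn(1, 3, 3)       # the kernel of one output channel
s = torch.sigmoid(patch * weight)   # placeholder: ‘s’ derived from patch and weight
out_pixel = torch.sum(patch * weight * s)  # one scalar of one output map
```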

So my custom conv layer’s forward function is implemented as follows (a code sketch is given after the list):

  1. pad and unfold the input into a tensor of shape (16, 1, 784, 9), where 784 = 28 * 28 and 9 = 1 * 3 * 3
  2. reshape the weight into a tensor of shape (32, 1, 9)
  3. calculate the additional weight ‘s’ from the unfolded input and the reshaped kernel weight; ‘s’ is a tensor of shape (16, 32, 784, 9)
  4. point-wise multiplication (with broadcasting) between the unfolded input (16, 1, 784, 9), the weight (32, 1, 9), and ‘s’ (16, 32, 784, 9), giving outputs of shape (16, 32, 784, 9)
  5. sum out the last dimension, and fold the third dimension (784) back into two dimensions (28, 28, the height and width of the output maps). Now I get the output I want as a tensor of shape (16, 32, 28, 28)
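For concreteness, here is a minimal sketch of that forward pass (compute_s is a hypothetical placeholder for however ‘s’ is actually computed):

```python
import torch
import torch.nn.functional as F

def custom_conv2d(x, weight, compute_s):
    # x: (N, C_in, H, W), weight: (C_out, C_in, kH, kW); stride 1, padding 1
    N, C_in, H, W = x.shape
    C_out, _, kH, kW = weight.shape

    # step 1: pad + unfold -> (N, C_in*kH*kW, H*W) -> (N, 1, H*W, C_in*kH*kW)
    patches = F.unfold(x, kernel_size=(kH, kW), padding=1)
    patches = patches.transpose(1, 2).unsqueeze(1)

    # step 2: reshape the kernel -> (C_out, 1, C_in*kH*kW)
    w = weight.reshape(C_out, 1, C_in * kH * kW)

    # step 3: the extra weight ‘s’ -> (N, C_out, H*W, C_in*kH*kW);
    # compute_s stands in for the actual derivation of ‘s’
    s = compute_s(patches, w)

    # step 4: broadcasted point-wise product -> (N, C_out, H*W, C_in*kH*kW)
    out = patches * w * s

    # step 5: sum the patch dimension and fold H*W back into (H, W)
    return out.sum(dim=-1).reshape(N, C_out, H, W)
```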

My custom conv layer gives the performance we expected, but the time and GPU memory costs are extremely high. I tried checkpointing to reduce the GPU memory cost, but it is still very high.
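For reference, the checkpointing I tried looks roughly like this, building on the custom_conv2d sketch above. Checkpointing avoids storing the large (16, 32, 784, 9) intermediate for the backward pass, but that tensor still exists momentarily during the forward itself, which may be why the peak memory stays high:

```python
from torch.utils.checkpoint import checkpoint

def run_conv(x, weight):
    # wrapper so that only tensors are passed through checkpoint;
    # compute_s is the placeholder from the sketch above
    return custom_conv2d(x, weight, compute_s)

# recompute the forward during backward instead of saving activations
out = checkpoint(run_conv, x, weight)
```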

So does anyone have suggestions for optimizing my implementation or solving these problems? I’m also interested in extending my custom conv in C++, but I cannot find where the regular convolution operation is implemented. (I know this file, but there seem to be no details about how the conv operation is done; what I could finally find is that it uses the thnn_conv2d_forward function.)

Sorry for my redundant English :P, and any info is welcome! :slight_smile:

Thank you!

The native implementations are mentioned in this post, which might be helpful in writing your custom implementation.

Thank you, ptrblck! :slight_smile: