cuDNN-accelerated autoencoder: computing the backward of cudnn_convolution.convolution_backward_input

Hi!

I have prototyped a convolutional autoencoder with two distinct sets of weights: w_f for the encoder and w_b for the decoder. I naturally used nn.Conv2d and nn.ConvTranspose2d to build the encoder and decoder respectively. Roughly, the goal is twofold: on the one hand, learn w_f so that it minimizes a loss function defined at the output; on the other hand, learn w_b so that it matches w_f^T. The forward pass is a convolution followed by a transpose convolution. So far, everything works out well.
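For concreteness, the baseline looks roughly like this (a minimal sketch; the channel counts and kernel size are placeholders):

import torch
import torch.nn as nn

class ConvAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        # encoder weights: w_f
        self.encoder = nn.Conv2d(3, 64, kernel_size=5, stride=2)
        # decoder weights: w_b (trained to match w_f^T)
        self.decoder = nn.ConvTranspose2d(64, 3, kernel_size=5, stride=2)

    def forward(self, x):
        y = self.encoder(x)   # convolution
        r = self.decoder(y)   # transpose convolution
        return r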

Now I want to accelerate this code using the cuDNN-accelerated C++ functions cudnn_convolution, cudnn_convolution_backward_weight and cudnn_convolution_backward_input (GitHub - jordan-g/PyTorch-cuDNN-Convolution: PyTorch extension enabling direct access to cuDNN-accelerated C++ convolution functions.). I am writing my own conv layer (inheriting from torch.autograd.Function) and defining its forward and backward methods by hand. In the forward method, there is a cudnn_convolution operation (parametrized by w_f) followed by a cudnn_convolution_backward_input operation (parametrized by w_b). In the backward method, in order to compute the gradient of the loss with respect to w_b, I therefore need to backpropagate through cudnn_convolution_backward_input. I thought I could simply use cudnn_convolution_backward_weight, but it does not work. What is the right way to compute the gradient with respect to w_b? I can provide the code of a concrete toy problem reproducing this issue if necessary. Many thanks for reading!

Hi,

That does sound like the right strategy.
In pseudo code, it would look like:

@staticmethod
def forward(ctx, inp, weight):
    ctx.save_for_backward(inp, weight)
    return cudnn_convolution(inp, weight)

@staticmethod
def backward(ctx, grad_out):
    inp, weight = ctx.saved_tensors
    return cudnn_convolution_backward_input(grad_out, weight), cudnn_convolution_backward_weight(grad_out, inp)

Doesn’t that work for you?

Dear Alban,

Thank you very much for your time!
Let me be more precise about the setting with the following pseudo-code:

    @staticmethod
    def forward(ctx, input, w_f, bias_f, w_b, stride, padding):
        # encoder: convolution parametrized by w_f
        y = cudnn_convolution.convolution(input, w_f, bias_f,
                                          stride, padding,
                                          (1, 1), 1, False, False)

        # decoder: transpose convolution parametrized by w_b, implemented
        # as the backward-input of a convolution
        r = cudnn_convolution.convolution_backward_input(input.shape, w_b, y,
                                                         stride, padding,
                                                         (1, 1), 1, False, False, False)

        ctx.save_for_backward(input, w_f, y, w_b)
        ctx.stride = stride
        ctx.padding = padding

        return r

    @staticmethod
    def backward(ctx, grad_r):
        stride, padding = ctx.stride, ctx.padding
        input, w_f, y, w_b = ctx.saved_tensors

        grad_w_b = cudnn_convolution.convolution_backward_weight(..?..)
        grad_y = cudnn_convolution.convolution_backward_input(..?..)
        grad_w_f = cudnn_convolution.convolution_backward_weight(input, w_f.shape, grad_y,
                                                                 stride, padding,
                                                                 (1, 1), 1, False, False, False)
        grad_bias_f = torch.sum(grad_y, dim=[0, 2, 3])

        return None, grad_w_f, grad_bias_f, grad_w_b, None, None

My question is therefore the following: how can I compute grad_w_b and grad_y? Can we also use cudnn_convolution.convolution_backward_input and cudnn_convolution.convolution_backward_weight with the “right” parameters, or should we use another C++ function? It would be incredibly helpful; I have looked everywhere on the web and could not find an answer :-(.

Why do you compute the backward with respect to the input in the forward? That should only happen during the backward.

Hi Alban,

In the forward pass, I want to do something strictly equivalent to ConvTranspose2d (as one would to define the forward pass of a convolutional autoencoder), but using these cuDNN-accelerated functions. That is why I thought of using cudnn_convolution.convolution_backward_input within the forward pass in the first place.
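Indeed, the two are the same operation; one can check the equivalence with stock PyTorch (a quick sanity check, with arbitrary shapes):

import torch
import torch.nn.functional as F

y = torch.randn(8, 64, 13, 13)   # encoder output
w = torch.randn(64, 3, 5, 5)     # conv weight of shape (C_out, C_in, k, k)

# transpose convolution applied to y with weight w ...
r1 = F.conv_transpose2d(y, w, stride=2)
# ... equals the gradient of the corresponding convolution w.r.t. its input
r2 = torch.nn.grad.conv2d_input((8, 3, 29, 29), w, y, stride=2)
assert torch.allclose(r1, r2, atol=1e-5)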

To put it differently: based on the cudnn_convolution.cpp file at this URL (PyTorch-cuDNN-Convolution/cudnn_convolution.cpp at master · ernoult/PyTorch-cuDNN-Convolution · GitHub), what should I add to this cpp file to define cudnn_convolution_transpose, cudnn_convolution_transpose_backward_weight and cudnn_convolution_transpose_backward_input? I should mention that I don’t know anything about cpp :confused:

In that case, you don’t want to do the convolution itself.
You can do that trick like we do it in pytorch core here: pytorch/ConvShared.cpp at a756a9e553ad23042fa39c6452d17aba72911abb · pytorch/pytorch · GitHub
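Concretely, the trick rests on the adjoint identity ⟨conv(x, w), y⟩ = ⟨x, conv_transpose(y, w)⟩: the transpose convolution is the backward-input of the convolution, its input gradient is a plain forward convolution, and its weight gradient is a backward-weight call with the roles of input and grad_output swapped. In terms of the wrapper signatures from your snippets above, the backward would look like this (a sketch; the explicit zero bias is an assumption about the convolution wrapper, which may also accept None):

# given grad_r = dL/dr, where r = convolution_backward_input(input.shape, w_b, y, ...)

# gradient w.r.t. y: a plain forward convolution of grad_r with w_b
zero_bias = torch.zeros(w_b.shape[0], device=grad_r.device, dtype=grad_r.dtype)
grad_y = cudnn_convolution.convolution(grad_r, w_b, zero_bias,
                                       stride, padding, (1, 1), 1, False, False)

# gradient w.r.t. w_b: backward-weight with input and grad_output swapped
# (input = grad_r, grad_output = y)
grad_w_b = cudnn_convolution.convolution_backward_weight(grad_r, w_b.shape, y,
                                                         stride, padding,
                                                         (1, 1), 1, False, False, False)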

Thank you so much Alban! This is exactly what I was looking for! I will try it out and let you know if it worked :slight_smile:


Hi again,

I am really sorry to bother you again, but I think I am failing at something that must be relatively simple. Along the lines of the cudnn_convolution.cpp file of this repo (PyTorch-cuDNN-Convolution/cudnn_convolution.cpp at master · ernoult/PyTorch-cuDNN-Convolution · GitHub), I would like to create wrappers around the C++-accelerated functions for convolution_transpose and convolution_transpose_backward_weight. In cudnn_convolution.cpp, I have therefore added:

at::Tensor convolution_transpose(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    c10::ArrayRef<int64_t> stride,
    c10::ArrayRef<int64_t> padding,
    c10::ArrayRef<int64_t> output_padding,
    c10::ArrayRef<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic) {

    return at::cudnn_convolution_transpose(
        input,
        weight,
        bias,
        padding,
        output_padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic); 
}


at::Tensor convolution_transpose_backward_weight(
    const at::Tensor& input,
    c10::ArrayRef<int64_t> weight_size,
    const at::Tensor& grad_output,
    c10::ArrayRef<int64_t> stride,
    c10::ArrayRef<int64_t> padding,
    c10::ArrayRef<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {

    return at::cudnn_convolution_transpose_backward_weight(
        weight_size,
        grad_output,
        input,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
}

Then, I load this cpp module from a Python script:

from torch.utils.cpp_extension import load

cudnn_convolution = load(name="cudnn_convolution", sources=["cudnn_convolution.cpp"], verbose=True)

Finally, I try to use the transpose convolution with the following script:

weight = torch.zeros(64, 3, 5, 5).to('cuda')
bias   = torch.zeros(3).to('cuda')   # out_channels = weight.shape[1] for a transpose convolution

stride   = (2, 2)
padding  = (0, 0)
output_padding  = (0, 0)
dilation = (1, 1)
groups   = 1
dummy_input = torch.zeros(128, 64, 14, 14).to('cuda')
output = cudnn_convolution.convolution_transpose(dummy_input, weight, bias, stride, padding, output_padding, dilation, groups, False, False)

but it throws the following error message:

AttributeError: module 'cudnn_convolution' has no attribute 'convolution_transpose'
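To see what the compiled module actually exports, I listed its attributes (to my understanding, only the functions registered with m.def(...) in the PYBIND11_MODULE block at the bottom of cudnn_convolution.cpp show up here):

# list the extension's exported functions; a wrapper defined in the .cpp
# file but not registered in PYBIND11_MODULE will be missing from this list
print([name for name in dir(cudnn_convolution) if not name.startswith('_')])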

Do you have any idea of what is happening? I am a bit desperate :confused: