Accessing the cuDNN convolution grad_input and grad_weight computation functions

Hi,

I am currently working on a problem of gradient pruning.

In the context of convolution, during backpropagation I need to alter the computation of the weight gradient (grad_weight) by slicing the output gradient (grad_output) before the actual gradient computation.

So, to achieve this, I need access to the functions that perform the backpropagation computation for convolution.

I don't want to use the functions conv2d_input and conv2d_weight because they are slow. Binding the function convolution_backward does not solve my problem either, since I also need to compute the gradient of the input (grad_input) without slicing grad_output.
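To make the requirement concrete, here is a minimal sketch using the public torch.nn.grad helpers. It assumes that "slicing" means keeping only the first `keep` samples of the batch; the helper name `pruned_conv2d_grads` and the `keep` parameter are hypothetical, and the pruning rule is only an illustration.

```python
import torch
import torch.nn.functional as F
from torch.nn.grad import conv2d_input, conv2d_weight


def pruned_conv2d_grads(x, w, grad_output, keep, stride=1, padding=1):
    """Hypothetical helper: grad_input uses the full grad_output, while
    grad_weight uses only the first `keep` samples (assumed pruning rule)."""
    # grad_input: computed from the full, unsliced grad_output
    grad_input = conv2d_input(x.shape, w, grad_output, stride=stride, padding=padding)
    # grad_weight: slice grad_output and the matching input samples first
    grad_weight = conv2d_weight(
        x[:keep], w.shape, grad_output[:keep], stride=stride, padding=padding
    )
    return grad_input, grad_weight


# Usage: batch of 4, weight gradient computed from 2 samples only
x = torch.randn(4, 3, 24, 24)
w = torch.randn(3, 3, 3, 3)
grad_output = torch.randn_like(F.conv2d(x, w, padding=1))
grad_input, grad_weight = pruned_conv2d_grads(x, w, grad_output, keep=2)
print(grad_input.shape, grad_weight.shape)
```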

I want to use CUDA/cuDNN functions to do the job.

My best guess so far was this example on GitHub:

This example uses cudnn_convolution_backward_weight and cudnn_convolution_backward_input to compute the corresponding gradients.

But, from what I understand, these functions have been removed from the ATen API (for code-rule compliance) and are no longer accessible (they are no longer listed in native_functions.yaml).

My point is that I really need fine control over the grad_input and grad_weight (and grad_bias) computations. So how can I do that?

More generally, how can I make functions that are not listed in native_functions.yaml accessible from Python/torch?

Sorry for the long post.
Thanks in advance for your answers and guidance.

These methods (torch.nn.grad.conv2d_input and conv2d_weight) still dispatch to cuDNN, as seen here:

import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda"
x = torch.randn(1, 3, 24, 24, requires_grad=True, device=device)
w = torch.randn(3, 3, 3, 3, requires_grad=True, device=device)
grad_output = torch.randn_like(x)


activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
sort_by_keyword = device + "_time_total"

with profile(activities=activities, record_shapes=True) as prof:
    wgrad = torch.nn.grad.conv2d_weight(x, w.shape, grad_output, stride=1, padding=1)
print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                              aten::convolution_backward        12.68%     212.324us        15.07%     252.310us     252.310us       4.608us       100.00%       4.608us       4.608us             1  
# void cudnn::cnn::wgrad2d_grouped_direct_kernel<false...         0.00%       0.000us         0.00%       0.000us       0.000us       4.608us       100.00%       4.608us       4.608us             1  
#                                         aten::new_empty         8.46%     141.694us        79.58%       1.332ms       1.332ms       0.000us         0.00%       0.000us       0.000us             1  
#                                             aten::empty        71.51%       1.197ms        71.51%       1.197ms     598.527us       0.000us         0.00%       0.000us       0.000us             2  
#                                            aten::expand         0.79%      13.165us         1.06%      17.703us      17.703us       0.000us         0.00%       0.000us       0.000us             1  
#                                        aten::as_strided         0.27%       4.538us         0.27%       4.538us       4.538us       0.000us         0.00%       0.000us       0.000us             1  
#                                                aten::to         0.10%       1.703us         0.10%       1.703us       1.703us       0.000us         0.00%       0.000us       0.000us             1  
#                                         cudaEventRecord         0.34%       5.731us         0.34%       5.731us       5.731us       0.000us         0.00%       0.000us       0.000us             1  
#                                   cudaStreamIsCapturing         0.06%       1.022us         0.06%       1.022us       1.022us       0.000us         0.00%       0.000us       0.000us             1  
#                                   cudaStreamGetPriority         0.06%       1.072us         0.06%       1.072us       1.072us       0.000us         0.00%       0.000us       0.000us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
# Self CPU time total: 1.674ms
# Self CUDA time total: 4.608us

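So the change in the interface should not affect the runtime: conv2d_weight calls aten::convolution_backward, which still launches a cuDNN weight-gradient kernel (cudnn::cnn::wgrad2d_grouped_direct_kernel in the profile above).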
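Because convolution_backward accepts an output_mask, the sliced-gradient requirement from the question can be met by calling it twice inside a custom autograd Function: once for grad_input with the full grad_output, once for grad_weight/grad_bias with a sliced one. This is a minimal sketch, not an official API: the class name `PrunedGradConv2d`, the `keep` parameter, and the "keep the first `keep` batch samples" pruning rule are hypothetical, and it assumes integer stride/padding, dilation 1, and groups 1.

```python
import torch
import torch.nn.functional as F

conv_backward = torch.ops.aten.convolution_backward


class PrunedGradConv2d(torch.autograd.Function):
    """Hypothetical conv2d whose grad_weight/grad_bias use only the first
    `keep` samples of grad_output, while grad_input uses all of it."""

    @staticmethod
    def forward(ctx, x, weight, bias, stride, padding, keep):
        ctx.save_for_backward(x, weight)
        ctx.cfg = (stride, padding, keep, bias is not None)
        return F.conv2d(x, weight, bias, stride=stride, padding=padding)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        stride, padding, keep, has_bias = ctx.cfg
        # Shared hyperparameters, in schema order: stride, padding, dilation,
        # transposed, output_padding, groups (assumed: dilation 1, groups 1).
        params = ([stride] * 2, [padding] * 2, [1, 1], False, [0, 0], 1)

        # 1) grad_input from the full grad_output; output_mask selects input only.
        grad_input = conv_backward(
            grad_output, x, weight, None, *params, (True, False, False)
        )[0]

        # 2) grad_weight (and grad_bias) from the sliced grad_output and input.
        bias_sizes = [weight.shape[0]] if has_bias else None
        _, grad_weight, grad_bias = conv_backward(
            grad_output[:keep], x[:keep], weight, bias_sizes, *params,
            (False, True, has_bias),
        )
        if not has_bias:
            grad_bias = None
        # One gradient per forward argument; non-tensor arguments get None.
        return grad_input, grad_weight, grad_bias, None, None, None


# Usage: the weight gradient only sees 2 of the 4 samples.
x = torch.randn(4, 3, 8, 8, requires_grad=True)
w = torch.randn(5, 3, 3, 3, requires_grad=True)
b = torch.randn(5, requires_grad=True)
PrunedGradConv2d.apply(x, w, b, 1, 1, 2).sum().backward()
```

Note that the sliced weight gradient is a sum over fewer samples; whether to rescale it (for example by batch_size / keep) is a design choice left out of this sketch. On CUDA tensors both calls should go through the same cuDNN kernels shown in the profile.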

Thanks a lot.

Your answer helped me understand a few things. Correct me if I'm wrong.

For example, conv2d_input calls convolution_backward as follows
(in pytorch/torch/nn/grad.py):

torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        None,
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,
        [0],
        groups,
        (True, False, False),
    )[0]

which is the main function for computing the convolution backward pass, and I can call it this way because it is listed in pytorch/aten/src/ATen/native/native_functions.yaml.
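The same call also covers grad_bias, which (as far as I can tell) torch.nn.grad has no helper for: setting the third entry of output_mask and passing bias_sizes returns it alongside grad_weight, while grad_input is skipped. A minimal sketch; the helper name `conv2d_weight_and_bias` is hypothetical, and the argument order mirrors the conv2d_input call above.

```python
import torch
from torch.nn.modules.utils import _pair


def conv2d_weight_and_bias(input, weight, grad_output, stride=1, padding=0,
                           dilation=1, groups=1):
    """Hypothetical helper: grad_weight and grad_bias from a single
    convolution_backward call; grad_input is skipped via output_mask."""
    _, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
        grad_output,
        input,
        weight,
        [weight.shape[0]],     # bias_sizes: one bias per output channel
        _pair(stride),
        _pair(padding),
        _pair(dilation),
        False,                 # transposed
        [0],                   # output_padding, as in conv2d_input
        groups,
        (False, True, True),   # output_mask: (grad_input, grad_weight, grad_bias)
    )
    return grad_weight, grad_bias


x = torch.randn(2, 3, 16, 16)
w = torch.randn(4, 3, 3, 3)
grad_output = torch.randn(2, 4, 16, 16)
gw, gb = conv2d_weight_and_bias(x, w, grad_output, padding=1)
# For conv2d, grad_bias is grad_output summed over batch and spatial dims.
print(torch.allclose(gb, grad_output.sum(dim=(0, 2, 3)), atol=1e-4))
```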

So, in the same way, I can call a function listed in this file to compute the forward pass of batch normalization:
torch.ops.aten._batch_norm_impl_index
which returns an output tensor, a save_mean tensor, a save_var tensor, a reserve tensor (and a fifth value that is not useful here).
Then I can compute the backward pass with the function below, as is done for convolution, using the saved save_mean, save_var, and reserve tensors.

(e.g. for grad_input):

grad_input = torch.ops.aten.batch_norm_backward(
    grad_output,
    input,
    weight,
    running_mean,
    running_var,
    save_mean,
    save_var,
    training,  # "update": should match the training flag of the forward pass
    eps,
    (True, False, False),  # output_mask: only grad_input
    reserve)[0]

Anyway, thank you.