F.grad.conv3d_weight High Memory Usage

Hi,

I have been using torch.nn.functional.grad.conv3d_weight to compute the gradient of the convolution kernel, but I have noticed that it uses much more memory than whatever method Autograd is calling. For context, I am working on implementing a form of reversible networks. These networks do not need to store activations in the forward pass, so I am avoiding the use of PyTorch’s autograd and writing my own backward method to take advantage of the memory efficiencies of these architectures.

More specifically, the line:

grad_output = grad_output.repeat(1, in_channels // groups, 1, 1, 1)

creates a very large intermediate tensor when in_channels is large.
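
For the shapes in the script below, that repeated tensor works out to exactly the 16 GiB allocation the traceback later complains about:

# grad_output has shape (N, out_channels, D, H, W) = (1, 64, 64, 128, 128)
elements = 1 * 64 * 64 * 128 * 128   # 67,108,864 elements, 0.25 GiB as float32
repeated = elements * (64 // 1)      # repeat(1, in_channels // groups, 1, 1, 1) with in_channels=64, groups=1
print(repeated * 4 / 2**30)          # 16.0 GiB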

I noticed this during training, so I have separated it out into this simple script in the hope of finding a better way to do it.

import torch
import torch.nn.functional as F

import argparse
import gc

def byte2mb(x):
    return x*1e-6

def mem_report():
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            print(type(obj), obj.size())

def check_mem(report=False):
    mem_alloc = byte2mb(torch.cuda.memory_allocated())
    mem_cached = byte2mb(torch.cuda.memory_cached())

    if report:
        mem_report()
    
    print('Mem Alloc: %6.2f, Mem Cached: %6.2f' % (mem_alloc, mem_cached))

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--no_autograd', dest='no_autograd', action='store_true')
    args = parser.parse_args()

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    kernel_size = 3

    # Define an input tensor and a convolution kernel
    (N, C, D, H, W) = (1, 64, 64, 128, 128)
    x = torch.rand(N, C, D, H, W, device=device)
    K = torch.rand(C, C, kernel_size, kernel_size, kernel_size, device=device, requires_grad=True)

    # Compute gradients without autograd
    if args.no_autograd:
        with torch.no_grad():
            y = F.conv3d(x, K, padding=kernel_size//2)
            dy = torch.ones_like(y)
            dK = F.grad.conv3d_weight(x, K.shape, dy, padding=kernel_size//2)
    else:
        # Compute gradients with autograd
        y = F.conv3d(x, K, padding=kernel_size//2)
        loss = y.sum()
        loss.backward()

    check_mem()

When run, I get:

$ python test_conv3d_backward.py
Mem Alloc: 537.31, Mem Cached: 807.40

$ python test_conv3d_backward.py --no_autograd
Traceback (most recent call last):
  File "test_conv3d_backward.py", line 42, in <module>
    dK = F.grad.conv3d_weight(x, K.shape, dy, padding=kernel_size//2)
  File ".../python3.7/site-packages/torch/nn/grad.py", line 287, in conv3d_weight
    grad_output = grad_output.repeat(1, in_channels // groups, 1, 1, 1)
RuntimeError: CUDA out of memory. Tried to allocate 16.00 GiB (GPU 0; 10.76 GiB total capacity; 768.42 MiB already allocated; 9.07 GiB free; 770.00 MiB reserved in total by PyTorch)

This test is not very rigorous, but clearly autograd is calling a different method, one that does not require creating the 16 GB intermediate tensor.

Is there any way for me to call that method directly? Ideally I would like to call the ConvBackward method that autograd is using, but I am fairly sure this is implemented in C++. A few years ago apaszke said that it wasn't possible, but there have been more recent discussions where people wrote their own C++ extensions.

Thanks

I ended up finding a solution by following this repo to create a C++ extension and wrapping it so that it behaves similarly to torch.nn.functional.grad.conv3d_weight. So far it solves my memory issue, and I am hoping to see a bit of a speed-up as well. I will update this post if I run into any issues, and with how/if I solve them.

In my research I found a long line of discussion on this topic over the years, where various people have been trying to call the cudnn_convolution family of functions directly from PyTorch. If anyone has insight into why this isn't exposed in PyTorch, I would be really curious to hear it. It seems like such a core piece of functionality for a deep learning framework.
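
As an aside, for a single op autograd's efficient kernel can also be reached from pure Python by rebuilding just that conv's graph and calling torch.autograd.grad on it. A rough sketch (the helper name is mine, and it pays for an extra forward pass of that one layer):

import torch
import torch.nn.functional as F

def conv3d_weight_via_autograd(x, weight, grad_output, stride=1, padding=0,
                               dilation=1, groups=1):
    # Detach the inputs so only this one op's graph is built, then let
    # autograd pick its usual (cuDNN) backward kernel for the weight gradient.
    x = x.detach()
    weight = weight.detach().requires_grad_(True)
    with torch.enable_grad():  # works even if called inside a no_grad() block
        y = F.conv3d(x, weight, stride=stride, padding=padding,
                     dilation=dilation, groups=groups)
    (grad_weight,) = torch.autograd.grad(y, weight, grad_output)
    return grad_weight

This avoids the repeat()-based fallback, but the extension below goes straight to the kernel without recomputing the forward.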

For the curious, here is what I ended up doing:

import torch
from torch.utils.cpp_extension import load
from torch.nn.modules.utils import _triple

# load the PyTorch extension
cudnn_convolution = load(name="cudnn_convolution", sources=["src/cudnn_convolution.cpp"], verbose=True)

def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the gradient of conv3d with respect to the weight of the convolution.
    Args:
        input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
    Examples::
        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv3d_weight(input, weight.shape, grad_output)
    """
    stride = _triple(stride)
    padding = _triple(padding)
    dilation = _triple(dilation)

    grad_weight = cudnn_convolution.convolution_backward_weight(
        input, 
        weight_size, 
        grad_output, 
        stride, 
        padding, 
        dilation, 
        groups, 
        False, # Benchmark
        False  # Deterministic
    )

    return grad_weight

if __name__ == "__main__":

    # create dummy input, convolutional weights and bias
    input  = torch.zeros(1, 3, 8, 16, 16).to('cuda')
    weight = torch.zeros(3, 3, 3, 3, 3).to('cuda')
    bias   = torch.zeros(3).to('cuda')

    # create dummy gradient w.r.t. the output
    grad_output = torch.rand_like(input)

    # compute the gradient w.r.t. the weights
    grad_weight = conv3d_weight(input, weight.shape, grad_output, stride=1, padding=1)
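
The src/cudnn_convolution.cpp that load() compiles is not shown above; it is presumably a thin wrapper around the corresponding ATen function. For anyone who wants a self-contained version, here is a sketch of the same idea using load_inline. It assumes a PyTorch release where at::cudnn_convolution_backward_weight is still exposed with this argument order (roughly the 1.5/1.6 era; newer releases have changed or removed it, so check your version):

import torch
from torch.utils.cpp_extension import load_inline

# Thin C++ wrapper around the cuDNN weight-gradient kernel. The ATen call and
# its argument order are assumptions tied to the PyTorch version noted above.
cpp_source = """
#include <torch/extension.h>
#include <vector>

at::Tensor convolution_backward_weight(
    const at::Tensor& input,
    std::vector<int64_t> weight_size,
    const at::Tensor& grad_output,
    std::vector<int64_t> stride,
    std::vector<int64_t> padding,
    std::vector<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic) {
  // Dispatches straight to cuDNN's weight-gradient kernel, so no repeated
  // copy of grad_output is ever materialized.
  return at::cudnn_convolution_backward_weight(
      weight_size, grad_output, input,
      padding, stride, dilation, groups, benchmark, deterministic);
}
"""

cudnn_convolution = load_inline(
    name="cudnn_convolution",
    cpp_sources=cpp_source,
    functions=["convolution_backward_weight"],  # auto-generates the pybind binding
    verbose=True,
)

The resulting cudnn_convolution module can then be dropped into the conv3d_weight wrapper above in place of the load(...) call.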