Custom implementation of Conv2d - Inconsistent Behavior

Hello all,

I’m trying to write a custom implementation of Conv2d following the docs in this link. I have one version that works and one that doesn’t, and the only difference between the two is a single line. I’m sure the problem comes from my limited knowledge of how the autograd system works, but I hope someone can tell me where to start troubleshooting.

This code works: models trained with it reach the same accuracy as with the built-in layer. It is just ~2x slower, because forward and backward are implemented in Python.

import torch
from torch.nn.functional import conv2d


class Conv2d(torch.autograd.Function):
    """2-D convolution implemented as a custom autograd Function.

    Forward delegates to ``torch.nn.functional.conv2d``. Backward computes the
    input and weight gradients with ``torch.nn.grad.conv2d_input`` and
    ``torch.nn.grad.conv2d_weight``, and the bias gradient by summation.

    Use through ``Conv2d.apply(input, weights, bias, stride, padding,
    dilation, groups)``.
    """

    @staticmethod
    def forward(ctx, input, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
        """Return ``conv2d(input, weights, bias, ...)`` and save what backward needs.

        Args:
            input: input tensor passed to ``conv2d``.
            weights: convolution kernel passed to ``conv2d``.
            bias: optional bias tensor, or ``None``.
            stride, padding, dilation, groups: forwarded unchanged to
                ``conv2d`` and reused in backward.
        """
        ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups
        # save_for_backward accepts None, so bias is saved as-is. The former
        # ``bias if bias else None`` called bool() on the tensor, which raises
        # for a multi-element bias and drops a single-element bias equal to 0.
        ctx.save_for_backward(input, weights, bias)
        # Pass the caller's ``groups`` so forward agrees with backward (it was
        # hard-coded to 1 here while backward used ``groups``).
        return conv2d(input, weights, bias=bias, stride=stride, padding=padding,
                      dilation=dilation, groups=groups)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input, weights, bias), then ``None`` for the
        four non-tensor arguments (stride, padding, dilation, groups)."""
        stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups
        input, weights, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weights, grad_output,
                                                    stride=stride, padding=padding,
                                                    dilation=dilation, groups=groups)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weights.shape, grad_output,
                                                      stride=stride, padding=padding,
                                                      dilation=dilation, groups=groups)
        if bias is not None and ctx.needs_input_grad[2]:
            # The bias is added at every output position of every sample, so for
            # a batched (N, C_out, H_out, W_out) grad_output its gradient sums
            # over batch and both spatial dims, leaving one value per channel.
            # ``sum(0)`` alone left a (C_out, H_out, W_out) tensor.
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        return grad_input, grad_weight, grad_bias, None, None, None, None

Now, this code tries to save the activations in a different format. The reason for this format is to enable a research use case that we are going to expand on. The below code doesn’t work. I will post the error at the end of this post.

import torch
from torch.nn.functional import conv2d

from my_package import MyTensor     # this is implemented in C++ and saves a tensor in a different format

class Conv2d(torch.autograd.Function):
    """2-D convolution whose saved input activation is kept as a ``MyTensor``.

    The math matches a plain custom Conv2d Function. The difference is that
    ``input`` is not saved with ``save_for_backward``. Instead,
    ``MyTensor(input)`` is stored on ``ctx``, and backward rebuilds a dense
    tensor with ``get_dense()``.

    NOTE(review): when only ``MyTensor(input)`` was stored, backward received
    a dense tensor of the wrong shape. Keeping any Python reference to
    ``input`` made the problem disappear. That suggests ``MyTensor`` does not
    own or reference the input's storage, so the memory is freed and reused
    by later activations once forward returns. TODO confirm in MyTensor's C++
    implementation: it should copy the data it needs or hold a strong
    reference to the tensor/storage.
    """

    @staticmethod
    def forward(ctx, input, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
        """Return ``conv2d(input, weights, bias, ...)``; keep ``input`` as a MyTensor."""
        ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups
        # Record the expected shape (a torch.Size, which does not keep the
        # tensor alive) so backward can verify what get_dense() returns.
        ctx.input_shape = input.shape
        # Stored as a ctx attribute because save_for_backward only accepts
        # tensors (or None), not a MyTensor object.
        ctx.input = MyTensor(input)
        # save_for_backward accepts None, so bias is saved as-is; the former
        # ``bias if bias else None`` called bool() on the tensor.
        ctx.save_for_backward(weights, bias)
        # Pass the caller's ``groups`` so forward agrees with backward.
        return conv2d(input, weights, bias=bias, stride=stride, padding=padding,
                      dilation=dilation, groups=groups)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input, weights, bias), then ``None`` for the
        four non-tensor arguments.

        Raises:
            RuntimeError: if ``ctx.input.get_dense()`` returns a tensor whose
                shape differs from the forward input's shape.
        """
        stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups
        # get_dense() is expected to reconstruct the forward ``input``.
        input = ctx.input.get_dense()
        if input.shape != ctx.input_shape:
            # Fail here with a clear message instead of an opaque size error
            # deep inside torch.nn.grad.conv2d_input / conv2d_weight.
            raise RuntimeError(
                "MyTensor.get_dense() returned a tensor of shape {} but the forward "
                "input had shape {}; the memory backing the saved activation may "
                "have been freed or reused after forward".format(
                    tuple(input.shape), tuple(ctx.input_shape)))
        weights, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weights, grad_output,
                                                    stride=stride, padding=padding,
                                                    dilation=dilation, groups=groups)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weights.shape, grad_output,
                                                      stride=stride, padding=padding,
                                                      dilation=dilation, groups=groups)
        if bias is not None and ctx.needs_input_grad[2]:
            # One bias value per output channel: sum a batched
            # (N, C_out, H_out, W_out) grad_output over batch and spatial dims.
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        return grad_input, grad_weight, grad_bias, None, None, None, None

So, the code fails in conv2d_input and conv2d_weight (and I will post the error later).

Now, to fix the above code, I only add one line in the forward function:

import torch
from torch.nn.functional import conv2d

from my_package import MyTensor

class Conv2d(torch.autograd.Function):
    """2-D convolution that keeps its input as a ``MyTensor`` for backward.

    This variant also holds a plain reference to the dense input (see the
    workaround comment in ``forward``). Backward rebuilds the input with
    ``ctx.input.get_dense()``.
    """

    @staticmethod
    def forward(ctx, input, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
        """Return ``conv2d(input, weights, bias, ...)``; keep ``input`` as a MyTensor."""
        # Workaround: holding a Python reference to ``input`` keeps its storage
        # alive until backward. That this fixes the failure suggests MyTensor
        # does not own/reference that storage, so without it the memory is
        # freed and reused after forward -- TODO confirm in MyTensor's C++ code.
        # Note this also keeps the full dense activation alive, which cancels
        # any memory saving MyTensor is meant to provide.
        ctx.any_variable_name = input
        ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups
        # Stored as a ctx attribute because save_for_backward only accepts
        # tensors (or None), not a MyTensor object.
        ctx.input = MyTensor(input)
        # save_for_backward accepts None, so bias is saved as-is; the former
        # ``bias if bias else None`` called bool() on the tensor.
        ctx.save_for_backward(weights, bias)
        # Pass the caller's ``groups`` so forward agrees with backward.
        return conv2d(input, weights, bias=bias, stride=stride, padding=padding,
                      dilation=dilation, groups=groups)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input, weights, bias), then ``None`` for the
        four non-tensor arguments."""
        stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups
        input = ctx.input.get_dense()
        weights, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weights, grad_output,
                                                    stride=stride, padding=padding,
                                                    dilation=dilation, groups=groups)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weights.shape, grad_output,
                                                      stride=stride, padding=padding,
                                                      dilation=dilation, groups=groups)
        if bias is not None and ctx.needs_input_grad[2]:
            # One bias value per output channel: sum a batched
            # (N, C_out, H_out, W_out) grad_output over batch and spatial dims.
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        return grad_input, grad_weight, grad_bias, None, None, None, None

Given this, I suspected a gap in my understanding of how the autograd system works: that the failing lines somehow depend on the input variable, so that when it is not saved in the ctx under some name, the functions conv2d_input and conv2d_weight fail.

But I dug into the code of these two functions, and they seem to depend only on their arguments.

My question now is: what am I missing that makes the flow work when I save the input in the ctx under any name, even though I never retrieve it in the backward pass?


For the additional error context, I’m testing the above on a resnet18 network that uses my Conv2d function instead of the built-in functions. My code looks like this (for testing purposes):

# Repro: run resnet18 (with the custom Conv2d swapped in) forward and backward.
model = models.resnet18()
inputs = torch.rand((64, 3, 224, 224))
labels = torch.randint(2, (64, ))

out = model(inputs)       # this always works
# nll_loss expects log-probabilities, but resnet18 returns raw logits, so apply
# log_softmax over the class dimension first (equivalent to cross_entropy).
loss = nll_loss(out.log_softmax(dim=1), labels)
loss.backward()            # it fails here

And the error in the code snippet that doesn’t work is:

test/test_resnet.py:21: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.venv/lib/python3.8/site-packages/torch-1.7.1-py3.8-linux-x86_64.egg/torch/tensor.py:221: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
.venv/lib/python3.8/site-packages/torch-1.7.1-py3.8-linux-x86_64.egg/torch/autograd/__init__.py:130: in backward
    Variable._execution_engine.run_backward(
.venv/lib/python3.8/site-packages/torch-1.7.1-py3.8-linux-x86_64.egg/torch/autograd/function.py:89: in apply
    return self._forward_cls.backward(self, *args)  # type: ignore
edgify/functional/conv2d.py:24: in backward
    grad_input = torch.nn.grad.conv2d_input(input.shape, weights, grad_output,
.venv/lib/python3.8/site-packages/torch-1.7.1-py3.8-linux-x86_64.egg/torch/nn/grad.py:163: in conv2d_input
    grad_input_padding = _grad_input_padding(grad_output, input_size, stride,
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

grad_output = tensor([[[[ 1.1969e-06, -1.1051e-06,  2.7628e-08,  ...,  2.7842e-06,
            1.4568e-06, -3.0673e-07],
          [...594e-07],
          [ 4.0056e-07, -2.0659e-05, -1.7700e-05,  ...,  1.7184e-06,
           -6.6292e-06, -1.0275e-05]]]])
input_size = [7, 7], stride = (2, 2), padding = (0, 0), kernel_size = (1, 1), dilation = (1, 1)

    def _grad_input_padding(grad_output, input_size, stride, padding, kernel_size, dilation=None):
        if dilation is None:
            # For backward compatibility
            warnings.warn("_grad_input_padding 'dilation' argument not provided. Default of 1 is used.")
            dilation = [1] * len(stride)
    
        input_size = list(input_size)
        k = grad_output.dim() - 2
    
        if len(input_size) == k + 2:
            input_size = input_size[-k:]
        if len(input_size) != k:
            raise ValueError("input_size must have {} elements (got {})"
                             .format(k + 2, len(input_size)))
    
        def dim_size(d):
            return ((grad_output.size(d + 2) - 1) * stride[d] - 2 * padding[d] + 1
                    + dilation[d] * (kernel_size[d] - 1))
    
        min_sizes = [dim_size(d) for d in range(k)]
        max_sizes = [min_sizes[d] + stride[d] - 1 for d in range(k)]
        for size, min_size, max_size in zip(input_size, min_sizes, max_sizes):
            if size < min_size or size > max_size:
>               raise ValueError(
                    ("requested an input grad size of {}, but valid sizes range "
                     "from {} to {} (for a grad_output of {})").format(
                         input_size, min_sizes, max_sizes,
                         grad_output.size()[2:]))
E               ValueError: requested an input grad size of [7, 7], but valid sizes range from [13, 13] to [14, 14] (for a grad_output of torch.Size([7, 7]))

.venv/lib/python3.8/site-packages/torch-1.7.1-py3.8-linux-x86_64.egg/torch/nn/grad.py:32: ValueError

Aside from the fact that I don’t know why input sizes were miscalculated, I don’t know why adding ctx.any_variable_name = input fixes the problem!

I appreciate any help!

1 Like