Suggested Method for Replacing Forward Function of Conv2d

I’m interested in subbing out PyTorch’s default convolution function for one generated by TVM (https://github.com/dmlc/tvm) in order to look at potential speed-ups. Ideally, this would be part of a custom autograd Function that uses a new forward call and conv2d’s default backward computation. However, I’m having a very hard time figuring out what to do for the backward pass. It seems that ConvNdBackward can’t be constructed from Python. My other thought was to implement the gradient myself using the conv and conv-transpose functions, but because these are part of nn, they don’t seem to be usable inside of Functions. I’m at a loss for how to proceed here. Any ideas on how I can get a working conv backward function with my own custom forward function?
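A minimal sketch of "custom forward, stock conv2d backward", written for current PyTorch (static-method `torch.autograd.Function`, not the legacy API used elsewhere in this thread). `fast_conv2d` is a hypothetical placeholder for the TVM-generated kernel; here it just calls `F.conv2d` so the sketch runs. The backward re-runs the stock `F.conv2d` under `torch.enable_grad()` and lets autograd supply the gradients, at the cost of recomputing the forward once per backward pass.

```python
import torch
import torch.nn.functional as F


def fast_conv2d(x, weight, bias, stride, padding):
    # Hypothetical stand-in for a TVM-generated kernel. It must compute the
    # same thing as F.conv2d, otherwise the reused backward is wrong.
    return F.conv2d(x, weight, bias, stride=stride, padding=padding)


class CustomForwardConv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, stride=1, padding=0):
        ctx.save_for_backward(x, weight, bias)
        ctx.stride, ctx.padding = stride, padding
        return fast_conv2d(x, weight, bias, stride, padding)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        # Autograd is disabled inside backward, so re-enable it, rebuild the
        # graph with the stock conv2d and let it supply the gradients.
        with torch.enable_grad():
            inputs = [t.detach().requires_grad_() for t in (x, weight, bias)]
            out = F.conv2d(*inputs, stride=ctx.stride, padding=ctx.padding)
        grads = torch.autograd.grad(out, inputs, grad_output)
        # stride and padding are plain ints, so they get no gradient.
        return (*grads, None, None)


# Usage: same argument order as F.conv2d (apply takes positional arguments).
x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.zeros(4, requires_grad=True)
y = CustomForwardConv2d.apply(x, w, b, 1, 1)
y.sum().backward()
```

Because the forward kernel is only a placeholder, this is correct only as long as the swapped-in kernel matches `F.conv2d` numerically.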

A more specific problem I’m having is that when I use a Conv module inside a forward or backward function, such as

import torch
import torch.nn as nn
from torch.autograd import Variable

# Old-style (non-static) autograd Function, as used at the time of this thread
class TestConv(torch.autograd.Function):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1):
        self.ConvFct = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)

    def forward(self, input):
        output = self.ConvFct(Variable(input))
        return output

    def backward(self, grad_output):
        return grad_output

I get a RuntimeError: _Map_base::at. If I could use conv functions this way, I could just implement the backprop rule myself.
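A sketch of implementing the backprop rule by hand, assuming current PyTorch. Instead of owning an `nn.Conv2d`, the Function receives the weight and bias as input tensors and uses the functional ops in `torch.nn.functional`; an `nn.Module` wrapper (the hypothetical `ManualConvLayer`) holds the parameters. To keep it short, the sketch assumes stride 1, dilation 1 and groups 1: the input gradient is a transposed convolution of `grad_output` with the same weight, and the weight gradient is a convolution of the input with `grad_output` after swapping the batch and channel dimensions. Other strides need `output_padding` and more bookkeeping.

```python
import torch
import torch.nn.functional as F


class ManualConv2d(torch.autograd.Function):
    """conv2d with a hand-written backward. Assumes stride=1, dilation=1, groups=1."""

    @staticmethod
    def forward(ctx, x, weight, bias, padding=0):
        ctx.save_for_backward(x, weight)
        ctx.padding = padding
        return F.conv2d(x, weight, bias, padding=padding)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        p = ctx.padding
        # dL/dx: the adjoint of a stride-1 conv2d is conv_transpose2d
        # with the same weight and padding.
        grad_x = F.conv_transpose2d(grad_output, weight, padding=p)
        # dL/dW: swap batch and channel dims so conv2d sums over the batch,
        # using grad_output as the "kernel"; then swap back to (C_out, C_in, kH, kW).
        grad_w = F.conv2d(x.transpose(0, 1), grad_output.transpose(0, 1), padding=p).transpose(0, 1)
        # dL/db: sum over batch and spatial positions.
        grad_b = grad_output.sum(dim=(0, 2, 3))
        # padding is a plain int, so it gets no gradient.
        return grad_x, grad_w, grad_b, None


class ManualConvLayer(torch.nn.Module):
    # Hypothetical wrapper: the Module owns the parameters and the
    # Function only ever sees plain tensors.
    def __init__(self, in_ch, out_ch, k, padding=0):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_ch, in_ch, k, k) * 0.1)
        self.bias = torch.nn.Parameter(torch.zeros(out_ch))
        self.padding = padding

    def forward(self, x):
        return ManualConv2d.apply(x, self.weight, self.bias, self.padding)
```

`torch.autograd.gradcheck` on double-precision inputs is a convenient way to confirm a hand-written backward like this one against finite differences.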

Could you provide a code snippet that causes the RuntimeError to happen?

Yes, it’s possible to reuse the Conv backward code from Python with something like the following:

f = torch._C._functions.ConvNdBackward
f(...)
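In later PyTorch releases, the conv2d gradients can also be computed from Python with the public helpers `torch.nn.grad.conv2d_input` and `torch.nn.grad.conv2d_weight`, which accept the same stride/padding/dilation/groups arguments as `F.conv2d`. A small sketch, assuming a release that ships these helpers, comparing them with autograd:

```python
import torch
import torch.nn.functional as F
from torch.nn.grad import conv2d_input, conv2d_weight

stride, padding = 2, 1
x = torch.randn(2, 3, 7, 7, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 3, 3, 3, dtype=torch.double, requires_grad=True)

out = F.conv2d(x, w, stride=stride, padding=padding)
grad_out = torch.randn_like(out)

# Reference gradients from autograd.
ref_gx, ref_gw = torch.autograd.grad(out, (x, w), grad_out)

# The same gradients from the standalone helpers, without building a graph.
gx = conv2d_input(x.shape, w.detach(), grad_out, stride=stride, padding=padding)
gw = conv2d_weight(x.detach(), w.shape, grad_out, stride=stride, padding=padding)

print(torch.allclose(gx, ref_gx), torch.allclose(gw, ref_gw))
```

These helpers could stand in for the backward of a custom Function whose forward calls a different kernel.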

Python is unable to even construct the ConvNdBackward function:

f = torch._C._functions.ConvNdBackward()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-56-d0c538bbbaa0> in <module>()
----> 1 f = torch._C._functions.ConvNdBackward()

RuntimeError: Cannot construct

The correct usage is

f = torch._C._functions.ConvNdBackward
f(...)

or
torch._C._functions.ConvNdBackward(...)
where ... are the arguments to ConvNdBackward.

torch._C._functions.ConvNdBackward isn’t a Function; it’s a direct call to the backend implementation of ConvNd.

import torch
from torch.autograd import Variable

input_new = Variable(torch.ones(1, 1, 1, 1))
weight_new = Variable(torch.ones(1, 1, 1, 1))
Function = torch._C._functions.ConvNdBackward
Function(input_new, weight_new, 1, 1)

Thanks a lot for pointing this out. Would you mind giving us a simple example of how to call this function, say in MNIST? For the example above, it still said "Cannot construct".

Interesting. I looked into ConvNdBackward, and I don’t think it can be called from Python the same way ConvNd can (it wasn’t implemented to be callable from Python). My apologies for misleading you.

You might be able to call it through C++ by working with ATen (source for the backward convolution here), but there isn’t an easier way to do this right now.
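In more recent PyTorch releases (roughly 1.11 onward; an assumption worth checking against your version), ATen’s backward convolution is reachable from Python as the operator `torch.ops.aten.convolution_backward`. It is an internal operator whose signature may change between releases, so treat the sketch below as illustrative: it requests all three gradients via `output_mask` and checks them against autograd.

```python
import torch
import torch.nn.functional as F

stride, padding, dilation, groups = [1, 1], [1, 1], [1, 1], 1
x = torch.randn(2, 3, 8, 8, dtype=torch.double)
w = torch.randn(4, 3, 3, 3, dtype=torch.double)
b = torch.randn(4, dtype=torch.double)

# Forward with autograd to get reference gradients.
xr, wr, br = (t.clone().requires_grad_() for t in (x, w, b))
out = F.conv2d(xr, wr, br, stride=stride, padding=padding, dilation=dilation, groups=groups)
grad_out = torch.randn_like(out)
ref = torch.autograd.grad(out, (xr, wr, br), grad_out)

# Direct call into ATen's backward convolution operator.
gx, gw, gb = torch.ops.aten.convolution_backward(
    grad_out, x, w,
    [w.shape[0]],         # bias_sizes: number of output channels
    stride, padding, dilation,
    False,                # transposed
    [0, 0],               # output_padding
    groups,
    [True, True, True],   # output_mask: grad_input, grad_weight, grad_bias
)
```

Setting an `output_mask` entry to `False` skips that gradient, which may save work when, for example, only the input gradient is needed.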