Hi,
I have been trying to implement a custom convolutional layer.
In order to do that, I’m using torch.nn.functional.conv2d in the forward pass, and both torch.nn.grad.conv2d_weight and torch.nn.grad.conv2d_input in the backward pass.
I started getting out-of-memory (OOM) exceptions as soon as execution enters torch.nn.grad.conv2d_weight.
My question is, what exactly is the difference between using:
torch.nn.functional.conv2d(x, w)
and
MyConv.apply(x, w)
when MyConv is implemented as follows:
```python
import torch
import torch.nn.functional as F
from torch.autograd import Function


class MyConv(Function):
    @staticmethod
    def forward(ctx, x, w):
        # Keep the input and weight around for the backward pass
        ctx.save_for_backward(x, w)
        return F.conv2d(x, w)

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        x_grad = w_grad = None
        if ctx.needs_input_grad[0]:
            x_grad = torch.nn.grad.conv2d_input(x.shape, w, grad_output)
        if ctx.needs_input_grad[1]:
            w_grad = torch.nn.grad.conv2d_weight(x, w.shape, grad_output)
        return x_grad, w_grad
```
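For completeness, this is roughly how I exercise the layer (a minimal sketch continuing from the class above; the shapes are arbitrary, and gradcheck wants double precision):

```python
import torch
from torch.autograd import gradcheck

# Arbitrary shapes, double precision so gradcheck's numerical gradients are stable
x = torch.randn(1, 3, 8, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 3, 3, 3, dtype=torch.double, requires_grad=True)

out = MyConv.apply(x, w)   # same forward result as F.conv2d(x, w)
out.sum().backward()       # runs the custom backward above

# Check the custom backward against numerical gradients
print(gradcheck(MyConv.apply, (x, w)))
```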
Why would torch.nn.grad.conv2d_weight raise an OOM exception when torch.nn.functional.conv2d (which I assumed also uses torch.nn.grad.conv2d_weight in its backward pass) did not?
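In case it helps, this is roughly how I compare peak memory between the two paths (a sketch with made-up shapes, continuing from the code above; it assumes a CUDA device):

```python
def peak_mem_mb(fn):
    # Reset the peak-memory counter, run fn, and report the new peak in MiB
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2

x = torch.randn(8, 64, 128, 128, device="cuda", requires_grad=True)
w = torch.randn(64, 64, 3, 3, device="cuda", requires_grad=True)

print(peak_mem_mb(lambda: F.conv2d(x, w).sum().backward()))      # built-in autograd
print(peak_mem_mb(lambda: MyConv.apply(x, w).sum().backward()))  # custom backward
```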
Thanks.