I have been trying to implement a custom convolutional layer.
To do that, I use torch.nn.functional.conv2d in the forward pass, and torch.nn.grad.conv2d_input together with torch.nn.grad.conv2d_weight in the backward pass.
I started getting OOM exceptions when entering torch.nn.grad.conv2d_weight.
My question is: what exactly is the difference between calling torch.nn.functional.conv2d directly and using my custom MyConv, when MyConv is implemented as follows:
```python
import torch
import torch.nn.functional as F
from torch.autograd import Function


class MyConv(Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return F.conv2d(x, w)

    @staticmethod
    def backward(ctx, grad_output):
        # saved_variables is deprecated; saved_tensors is the current API
        x, w = ctx.saved_tensors
        x_grad = w_grad = None
        # needs_input_grad is a tuple of booleans, one per forward argument,
        # so it must be indexed (otherwise the truthiness test is meaningless)
        if ctx.needs_input_grad[0]:
            x_grad = torch.nn.grad.conv2d_input(x.shape, w, grad_output)
        if ctx.needs_input_grad[1]:
            w_grad = torch.nn.grad.conv2d_weight(x, w.shape, grad_output)
        return x_grad, w_grad
```
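For reference, here is a minimal sanity check (my own sketch, not part of the original post; the class is repeated so the snippet is self-contained) that compares the gradients produced by MyConv against autograd's built-in backward for F.conv2d:

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.grad import conv2d_input, conv2d_weight


class MyConv(Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return F.conv2d(x, w)

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        x_grad = w_grad = None
        if ctx.needs_input_grad[0]:
            x_grad = conv2d_input(x.shape, w, grad_output)
        if ctx.needs_input_grad[1]:
            w_grad = conv2d_weight(x, w.shape, grad_output)
        return x_grad, w_grad


# Small example tensors (sizes chosen arbitrarily for illustration)
x = torch.randn(2, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)

# Gradients via the custom Function
MyConv.apply(x, w).sum().backward()
gx, gw = x.grad.clone(), w.grad.clone()

# Gradients via autograd's native conv2d backward
x.grad = w.grad = None
F.conv2d(x, w).sum().backward()

print(torch.allclose(gx, x.grad, atol=1e-4))
print(torch.allclose(gw, w.grad, atol=1e-4))
```

Both comparisons should print True: the custom backward is mathematically equivalent, so the memory difference must come from how the gradient is computed, not from what it computes.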
Why would torch.nn.grad.conv2d_weight raise an OOM exception when torch.nn.functional.conv2d (which I assume also uses torch.nn.grad.conv2d_weight in its backward pass) did not?