It seems that when we write a custom autograd function, the input and output tensors of the forward function must have the same shape; otherwise, calling backward raises a RuntimeError. Here is a minimal example:
import torch

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.sum(x, dim=1)  # output shape [3], input shape [3, 4]
        return y

    @staticmethod
    def backward(ctx, grad_output):
        g = grad_output * 2      # shape [3], but the input was [3, 4]
        return g

a = torch.randn((3, 4), requires_grad=True)
b = MyFunc.apply(a)
c = torch.sum(b)
c.backward()
Running this raises the following RuntimeError:
RuntimeError: Function MyFuncBackward returned an invalid gradient at index 0 - got [3] but expected shape compatible with [3, 4]
So does this mean that the input and output of forward (x and y in the example) must have the same shape? What if I want to write a function whose forward changes the shape of its input?
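For comparison, here is a variant I tried (the class name SumDim1 is my own) where backward expands grad_output back to the input's shape instead of returning it as-is. This version runs without the error, which makes me suspect the rule is about the gradient's shape matching the input's shape, not about forward preserving shapes:

```python
import torch

class SumDim1(torch.autograd.Function):
    """Sums over dim 1; backward returns a gradient shaped like the input."""

    @staticmethod
    def forward(ctx, x):
        ctx.x_shape = x.shape            # remember the input shape for backward
        return torch.sum(x, dim=1)       # [3, 4] -> [3]

    @staticmethod
    def backward(ctx, grad_output):
        # d(sum)/dx is 1 everywhere, so broadcast grad_output back to the
        # input shape: [3] -> [3, 1] -> [3, 4]
        return grad_output.unsqueeze(1).expand(ctx.x_shape)

a = torch.randn((3, 4), requires_grad=True)
b = SumDim1.apply(a)
c = torch.sum(b)
c.backward()
print(a.grad.shape)  # matches a's shape, [3, 4]
```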