I'm getting gradients with strides like (0, 0) when writing custom PyTorch autograd functions. Here's a minimal, reproducible example:
import torch

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        print(grad_output.shape, grad_output.stride())
        return grad_output

# Create a callable object for the Identity function
identity = Identity.apply

# Example usage
x = torch.randn(3, 3, requires_grad=True)
y = identity(x)
z = y.sum()
z.backward()
print(x.grad)  # Should be a 3x3 tensor of ones
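Running this, the print in backward shows torch.Size([3, 3]) (0, 0). I can reproduce the same zero-stride layout directly; I assume this kind of expanded tensor is what autograd hands to backward after a sum():

import torch

# A scalar expanded to (3, 3) has strides (0, 0): every element aliases
# the same underlying value, and the tensor is not contiguous.
g = torch.ones(()).expand(3, 3)
print(g.shape, g.stride())  # torch.Size([3, 3]) (0, 0)
print(g.is_contiguous())    # False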
These zero strides cause problems, and ultimately incorrect results, in more complicated autograd functions where I try to make grad_output contiguous.
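For illustration, here is a minimal sketch of the kind of function I mean (the Double function and its backward are hypothetical, only meant to show where the .contiguous() call sits):

import torch

class Double(torch.autograd.Function):
    # Hypothetical example: doubles the input.
    @staticmethod
    def forward(ctx, input):
        return input * 2

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output may arrive expanded with strides (0, 0), e.g. after
        # a .sum() loss, so materialize it before any code that assumes a
        # densely laid-out tensor.
        grad = grad_output.contiguous()
        return grad * 2

x = torch.randn(3, 3, requires_grad=True)
Double.apply(x).sum().backward()
print(x.grad)  # 3x3 tensor filled with 2s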