Yes, I agree.
I must say, returning a NaN gradient when computing std over an all-constant tensor is just so strange! The natural behavior would be to return an all-zero gradient, wouldn't it? Maybe PyTorch should fix it.
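For anyone who wants to reproduce it, this is all it takes (a minimal sketch; any all-constant tensor will do):

import torch

x = torch.ones(4, requires_grad=True)   # constant input, so its std is 0
x.std().backward()
print(x.grad)                           # all NaN on the versions I have tried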
For others who have encountered the same problem, I wrote my own CustomStd like this:
import functools
import torch

class CustomStd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim, eps=1e-5, unbiased=True, keepdim=False):
        # Deviations from the mean over the reduced dims
        dev = input - input.mean(dim=dim, keepdim=True)
        ctx.save_for_backward(input)
        ctx.eps = eps
        ctx.dev = dev
        ctx.numdim = input.dim()
        # Number of elements reduced over (Bessel-corrected when unbiased)
        ctx.numel = functools.reduce(lambda x, y: x * y, [input.size(d) for d in dim])
        if unbiased:
            ctx.numel -= 1
        ctx.std = torch.sqrt(torch.sum(dev * dev, dim=dim, keepdim=True) / ctx.numel)
        return ctx.std if keepdim else ctx.std.squeeze()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output
        # Restore the reduced dims so grad_output broadcasts against dev and std
        for i in range(grad_output.dim(), ctx.numdim):
            grad_input = grad_input.unsqueeze(i)
        # The exact derivative of std is dev / (numel * std); the
        # (numel - 1) / numel**2 factor below is a close approximation for large
        # numel, and the eps keeps the gradient at 0 (instead of NaN) when the
        # input is constant and std is 0.
        grad_input = ctx.dev * (ctx.numel - 1) / (ctx.numel**2) / (ctx.std + ctx.eps) * grad_input
        return grad_input, None, None, None, None
It can pass a test like this:
input = torch.randn(16, 256, 14, 14, dtype=torch.float, requires_grad=True)
std1 = input.std(dim=(2, 3))
std2 = CustomStd.apply(input, (2, 3))
torch.sum(std1).backward(retain_graph=True)
grad1 = input.grad
input.grad = None   # clear so the second backward writes into a fresh .grad
torch.sum(std2).backward()
grad2 = input.grad
############### TEST
torch.allclose(std1, std2)                # True
torch.allclose(grad1, grad2, atol=1e-3)   # True
In the second test (the gradient), I had to set atol=1e-3 to make it pass, so there is a subtle difference between CustomStd and torch.std. I suspect it comes from the backward coefficient, since CustomStd uses (numel - 1) / numel**2 where the exact gradient of std has 1 / numel, and from the eps added to the denominator.
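As a rough check (my own arithmetic on the example above, not something I measured separately), the coefficient gap for the 14x14 reduction is:

numel = 14 * 14 - 1                  # Bessel-corrected element count over dims (2, 3)
exact = 1 / numel                    # coefficient in the exact std gradient: dev / (numel * std)
approx = (numel - 1) / numel**2      # coefficient used in CustomStd.backward
print(exact - approx)                # ~2.6e-05, comfortably under atol=1e-3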
Also, CustomStd now returns an all-zero gradient for an all-constant tensor.
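For example (a quick sanity check; the shape is arbitrary):

const = torch.ones(2, 3, 4, 4, requires_grad=True)   # all-constant input
out = CustomStd.apply(const, (2, 3))
torch.sum(out).backward()
print(torch.isnan(const.grad).any())    # tensor(False)
print(const.grad.abs().max())           # tensor(0.) -- zero gradient, no NaN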