Fails for the following function:
class FRUFunction(Function):
@staticmethod
def forward(ctx, s, a):
k = 1 / (a + s - a * s)
ctx.save_for_backward(s, k, a)
return k * s
@staticmethod
def backward(ctx, grad_output):
s, k, a = ctx.saved_tensors
grad_s = grad_a = None
if ctx.needs_input_grad[0]:
grad_s = a / (a + s - a*s)** 2
return grad_s, grad_a
Check with code:
input = torch.tensor([[0.0971, 0.5413, 0.8107]], dtype=torch.double, requires_grad=True)
alpha = torch.tensor([[0.2904, -0.3183, 0.7001]], dtype=torch.double, requires_grad=False)
fru_input = (
input,
alpha,
)
test = gradcheck(fru, fru_input, eps=1e-3)
It throws runtime exception:
RuntimeError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[ 2.2495, 0.0000, 0.0000],
[ 0.0000, -2.0370, 0.0000],
[ 0.0000, 0.0000, 0.7869]], dtype=torch.float64)
analytical:tensor([[ 2.2495, 2.2495, 2.2495],
[-2.0370, -2.0370, -2.0370],
[ 0.7869, 0.7869, 0.7869]], dtype=torch.float64)
NOTE: It even fails for
@staticmethod
def forward(ctx, s, a):
k = 1 / (a + s - a * s)
ctx.save_for_backward(s, k, a)
return k
@staticmethod
def backward(ctx, grad_output):
s, k, a = ctx.saved_tensors
grad_s = grad_a = None
if ctx.needs_input_grad[0]:
grad_s = (a-1) / ((a + s - a*s) ** 2)
return grad_s, grad_a
and for
@staticmethod
def forward(ctx, s, a):
k = a + s - a * s
ctx.save_for_backward(s, k, a)
return k
@staticmethod
def backward(ctx, grad_output):
s, k, a = ctx.saved_tensors
grad_s = grad_a = None
if ctx.needs_input_grad[0]:
grad_s = a-1
return grad_s, grad_a
and for
@staticmethod
def forward(ctx, s, a):
k = a * s
ctx.save_for_backward(s, k, a)
return k
@staticmethod
def backward(ctx, grad_output):
s, k, a = ctx.saved_tensors
grad_s = grad_a = None
if ctx.needs_input_grad[0]:
grad_s = a
return grad_s, grad_a