I am learning autograd and was following the "Extending PyTorch" tutorials.
I would like to compute gradients with respect to x and y (named v inside the Function) for the following function, which returns both o = matmul(tanh(x), y) and a = tanh(x).
import torch
from torch.autograd import Function

class Func(Function):
    @staticmethod
    def forward(x, v):
        a = torch.tanh(x)
        o = torch.matmul(a, v)
        return o, a

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, v = inputs
        o, a = outputs
        ctx.save_for_backward(x, v, a)

    @staticmethod
    def backward(ctx, grad_o, grad_a):
        x, v, a = ctx.saved_tensors
        dL_da = torch.matmul(grad_o, v.transpose(1, 0))
        dL_dv = torch.matmul(a.transpose(1, 0), grad_o)
        dL_dx = dL_da / torch.cosh(x) ** 2
        return dL_dx, dL_dv
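Just to illustrate the shapes involved (this snippet is not from my original script; the shapes simply follow from the inputs I use below), the op has two outputs:

x = torch.randn(2, 3, dtype=torch.double)
y = torch.randn(3, 2, dtype=torch.double)
o, a = Func.apply(x, y)
print(o.shape, a.shape)  # torch.Size([2, 2]) torch.Size([2, 3])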
When I run gradcheck as

from torch.autograd import gradcheck

func = Func.apply
x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = torch.randn(3, 2, requires_grad=True, dtype=torch.double)
test = gradcheck(func, (x, y), eps=1e-6, atol=1e-4)
it raises the following error:
GradcheckError: Jacobian mismatch for output 1 with respect to input 0,
numerical:tensor([[0.4000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.7561, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.9592, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.9947, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.9674, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4168]], dtype=torch.float64)
analytical:tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]], dtype=torch.float64)
From the message, output 1 should be a and input 0 should be x, and the analytical Jacobian for that pair is all zeros, but I am not sure what I am doing wrong here. Can someone help me figure it out?
Other notes
I tried out the simpler functions tanh and matmul on their own, and they seem to work fine:
class Mm(Function):
    @staticmethod
    def forward(x, y):
        return torch.matmul(x, y)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, y = inputs
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        dx = torch.matmul(grad_output, y.transpose(1, 0))
        dy = torch.matmul(x.transpose(1, 0), grad_output)
        return dx, dy

class Tanh(Function):
    @staticmethod
    def forward(x):
        return torch.tanh(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_y):
        x, = ctx.saved_tensors
        return grad_y / torch.cosh(x) ** 2
mm = Mm.apply
x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = torch.randn(3, 2, requires_grad=True, dtype=torch.double)
test = gradcheck(mm, (x, y), eps=1e-6, atol=1e-4)
tanh = Tanh.apply
x = torch.randn(1, 3, requires_grad=True, dtype=torch.double)
test = gradcheck(tanh, x, eps=1e-6, atol=1e-4)
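For comparison (not something from my script, just to show the behaviour I expect), the same computation written with built-in ops only should, as far as I understand, pass gradcheck, since autograd derives the backward for tanh and matmul itself:

def ref(x, y):
    a = torch.tanh(x)
    return torch.matmul(a, y), a

x = torch.randn(2, 3, requires_grad=True, dtype=torch.double)
y = torch.randn(3, 2, requires_grad=True, dtype=torch.double)
test = gradcheck(ref, (x, y), eps=1e-6, atol=1e-4)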