I have some problems with pytorch gradcheck.
class KL(Function):
# vectorized
# input is batch_size x dim
# Note that both forward and backward are @staticmethods
@staticmethod
def forward(ctx, logits, beta=1):
log_z = 94.358391688507595 # partition function for dim=300
ctx.save_for_backward(logits)
ctx.beta = beta
p = logits.sigmoid()
# more computationally stable entropy
entropy = (-logits * (p - 1) + F.softplus(-logits).data)
entropy[entropy != entropy] = 0 # when probs are ones or zeros -> entropy is zero
a = p[:, :-1]
b = p[:, 1:]
binary_potentials = (a * (1 - b) + (1 - a) * b)
out = entropy.neg().sum(dim=1) + log_z
out = out + beta * binary_potentials.sum(dim=1)
out = out.sum(dim=0, keepdim=True)
return out
@staticmethod
def backward(ctx, grad_output):
logits = ctx.saved_variables[0]
beta = ctx.beta
logits_exp = torch.exp(logits)
term_1 = (logits_exp * logits) / (logits_exp + 1)**2
p = logits.sigmoid()
prev = torch.zeros(*p.size())
prev[:, 1:] = p.data[:, :-1]
prev = Variable(prev)
nxt = torch.zeros(*p.size())
nxt[:, :-1] = p.data[:, 1:]
nxt = Variable(nxt)
if p.is_cuda:
prev = prev.cuda()
nxt = nxt.cuda()
mul_term = ( 2 - 2 * prev - 2 * nxt)
mul_term[:, 0] = mul_term[:, 0] - 1
mul_term[:, -1] = mul_term[:, -1] - 1
term_2 = beta * p * (1 - p) * mul_term
grad_input = term_1 + term_2
grad_input[grad_input.data != grad_input.data] = 0
return grad_output * grad_input, None
I have custom function showed above, and gradcheck fails:
logits = Variable(torch.randn(5, 300), requires_grad=True)
inpt = (logits, 1)
gradcheck(KL.apply, inpt, eps=1e-6, atol=1e-4)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 181, in gradcheck
return fail_test('for output no. %d,\n numerical:%s\nanalytical:%s\n' % (j, numerical, analytical))
File "/usr/local/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 166, in fail_test
raise RuntimeError(msg)
RuntimeError: for output no. 0,
numerical:(
0
0
0
⋮
0
0
0
[torch.FloatTensor of size 1500x1]
,)
analytical:(
-0.1566
-0.0927
0.1289
⋮
-0.0009
0.2219
-0.2261
[torch.FloatTensor of size 1500x1]
,)
But when I check function and its gradients with scipy.optimize.check_grad all is ok and error is less than 1e-6. As I see, the problem is with finite differences, but I can’t understand why. Forward pass seems legit.
Pytorch versions: 0.3.0.post4, 0.2.0.post4’