Hello all,
I’m new to PyTorch so this might turn out to be an easy fix, but I’m not able to get it working.
I’m trying to implement a custom logsigmoid
function based on table lookup. I have the following code for my custom function:
class LogSigmoid(Function):
@staticmethod
def forward (ctx, tensor):
ctx.save_for_backward(tensor)
return TableLookup.logsigmoid (tensor)
@staticmethod
def backward (ctx, grad_output):
tensor, = ctx.saved_tensors
grad_input = 1 - TableLookup.sigmoid(tensor)
return grad_output * grad_input
from torch.autograd import gradcheck
input = (torch.randn(3,3,dtype=torch.double,requires_grad=True),)
test = gradcheck(LogSigmoid.apply, input, eps=1e-6, atol=1e-4)
print(test)
This fails with the error message:
RuntimeError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64)
analytical:tensor([[0.7356, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.4344, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.4974, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.2663, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.3699, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6546, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4759, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2895, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3435]],
dtype=torch.float64)
If I replace the call to TableLookup.logsigmoid
by torch.nn.functional.logsigmoid
, then gradcheck
works as expected. This led me to think that my lookup-based implementation might be incorrect, but I calculated the output using both torch.nn.functional.logsigmoid
and LogSigmoid.apply
and they are pretty much the same.
from torch.autograd import gradcheck
input = (torch.randn(3,3,dtype=torch.double,requires_grad=True),)
print (F.logsigmoid(input[0]))
print (LogSigmoid.apply(input[0]))
test = gradcheck(LogSigmoid.apply, input, eps=1e-6, atol=1e-4)
tensor([[-0.2201, -1.0244, -0.1658],
[-0.4461, -0.2707, -0.3701],
[-0.9809, -0.5292, -0.3277]], dtype=torch.float64,
grad_fn=<LogSigmoidBackward>)
tensor([[-0.2201, -1.0245, -0.1659],
[-0.4461, -0.2707, -0.3701],
[-0.9809, -0.5292, -0.3277]], dtype=torch.float64,
grad_fn=<LogSigmoidBackward>)
For completeness, the TableLookup
class is as follows:
def sigm (x): return 1/(1+np.exp(-x))
class TableLookup (object):
TABLE_SIZE = 100000
MAX_EXP = 10
EXP_TABLE = sigm ((np.arange (TABLE_SIZE) / TABLE_SIZE * 2 - 1) * MAX_EXP) #sigmoid table
LOG_TABLE = np.log (EXP_TABLE) #logsigmoid table
EXP_TABLE = torch.tensor (EXP_TABLE, dtype=torch.double) # TODO: how to do this is in the *device* memory?
LOG_TABLE = torch.tensor (LOG_TABLE, dtype=torch.double)
@classmethod
def sigmoid (cls, tensor):
""" For a given query tensor, calculate the sigmoid values"""
# clamp
tensor = torch.clamp (tensor, -TableLookup.MAX_EXP, ((TableLookup.TABLE_SIZE-1) / TableLookup.TABLE_SIZE * 2 - 1) * TableLookup.MAX_EXP)
# map to indices
indices = ((tensor + TableLookup.MAX_EXP) * TableLookup.TABLE_SIZE/2/TableLookup.MAX_EXP).long()
# lookup
return TableLookup.EXP_TABLE[indices]
@classmethod
def logsigmoid (cls, tensor):
""" For a given query tensor, calculate the logsigmoid values as:
logsigmoid (x) = x if x < -MAX_EXP
= lookup (x) if x >= -MAX_EXP and x < MAX_EXP
= 0 if x >= MAX_EXP
"""
# if x < -TableLookup.MAX_EXP:
mask = tensor < -TableLookup.MAX_EXP
t1 = tensor * mask.double()
tensor = torch.clamp (tensor, -TableLookup.MAX_EXP, ((TableLookup.TABLE_SIZE-1) / TableLookup.TABLE_SIZE * 2 - 1) * TableLookup.MAX_EXP)
# map to indices
indices = ((tensor + TableLookup.MAX_EXP) * TableLookup.TABLE_SIZE/2/TableLookup.MAX_EXP).long()
# lookup
t2 = TableLookup.LOG_TABLE[indices]
# everywhere the mask is 1 should be converted to zero in t2
t2[mask] = 0
return t1 + t2.double()
Can someone explain what I might be doing wrong and how I may be able to fix it?