Gradcheck fails for custom function

Hello all,

I’m new to PyTorch so this might turn out to be an easy fix, but I’m not able to get it working.

I’m trying to implement a custom logsigmoid function based on a table lookup. I have the following code for my custom function:

from torch.autograd import Function

class LogSigmoid(Function):
    @staticmethod
    def forward(ctx, tensor):
        ctx.save_for_backward(tensor)
        return TableLookup.logsigmoid(tensor)

    @staticmethod
    def backward(ctx, grad_output):
        tensor, = ctx.saved_tensors
        # d/dx logsigmoid(x) = 1 - sigmoid(x)
        grad_input = 1 - TableLookup.sigmoid(tensor)
        return grad_output * grad_input
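
For reference, the backward relies on the identity d/dx logsigmoid(x) = 1 - sigmoid(x). A minimal check of that identity against autograd on the built-in logsigmoid (just an illustration, separate from the code above):

import torch
import torch.nn.functional as F

x = torch.randn(5, dtype=torch.double, requires_grad=True)
F.logsigmoid(x).sum().backward()
print(torch.allclose(x.grad, 1 - torch.sigmoid(x)))  # expected: True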

import torch
from torch.autograd import gradcheck

input = (torch.randn(3, 3, dtype=torch.double, requires_grad=True),)
test = gradcheck(LogSigmoid.apply, input, eps=1e-6, atol=1e-4)
print(test)

This fails with the error message:

RuntimeError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64)
analytical:tensor([[0.7356, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4344, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4974, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2663, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3699, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6546, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4759, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2895, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3435]],
       dtype=torch.float64)
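
As far as I understand, gradcheck estimates the numerical Jacobian by perturbing each input element by eps and taking central differences, roughly along the lines of this sketch (my own simplification, not the actual gradcheck implementation):

def numerical_jacobian(fn, x, eps=1e-6):
    # rough central-difference Jacobian: one row per input element (sketch only)
    x = x.detach().clone()
    flat = x.reshape(-1)  # view into x, so the writes below perturb x
    n_out = fn(x).reshape(-1).numel()
    jac = torch.zeros(flat.numel(), n_out, dtype=x.dtype)
    for i in range(flat.numel()):
        orig = flat[i].item()
        flat[i] = orig + eps
        plus = fn(x).reshape(-1)
        flat[i] = orig - eps
        minus = fn(x).reshape(-1)
        flat[i] = orig
        jac[i] = (plus - minus) / (2 * eps)
    return jac

The analytical Jacobian, as I understand it, comes from calling backward with one-hot grad_output vectors, which is why it has the diagonal structure above.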

If I replace the call to TableLookup.logsigmoid with torch.nn.functional.logsigmoid, then gradcheck passes as expected. This led me to think that my lookup-based implementation might be incorrect, but when I compute the output with both torch.nn.functional.logsigmoid and LogSigmoid.apply, the results agree to within about 1e-4:

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

input = (torch.randn(3, 3, dtype=torch.double, requires_grad=True),)
print(F.logsigmoid(input[0]))
print(LogSigmoid.apply(input[0]))
test = gradcheck(LogSigmoid.apply, input, eps=1e-6, atol=1e-4)
tensor([[-0.2201, -1.0244, -0.1658],
        [-0.4461, -0.2707, -0.3701],
        [-0.9809, -0.5292, -0.3277]], dtype=torch.float64,
       grad_fn=<LogSigmoidBackward>)
tensor([[-0.2201, -1.0245, -0.1659],
        [-0.4461, -0.2707, -0.3701],
        [-0.9809, -0.5292, -0.3277]], dtype=torch.float64,
       grad_fn=<LogSigmoidBackward>)
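
To put a number on "pretty much the same", one can also look at the maximum difference between the two (again just an illustration):

x = torch.randn(1000, dtype=torch.double)
print((F.logsigmoid(x) - LogSigmoid.apply(x)).abs().max())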

For completeness, the TableLookup class is as follows:

import numpy as np
import torch

def sigm(x):
    return 1 / (1 + np.exp(-x))

class TableLookup(object):
    TABLE_SIZE = 100000
    MAX_EXP = 10

    EXP_TABLE = sigm((np.arange(TABLE_SIZE) / TABLE_SIZE * 2 - 1) * MAX_EXP)  # sigmoid table
    LOG_TABLE = np.log(EXP_TABLE)  # logsigmoid table
    EXP_TABLE = torch.tensor(EXP_TABLE, dtype=torch.double)  # TODO: how to put this in *device* memory?
    LOG_TABLE = torch.tensor(LOG_TABLE, dtype=torch.double)

    @classmethod
    def sigmoid(cls, tensor):
        """For a given query tensor, calculate the sigmoid values."""
        # clamp to the range covered by the table
        tensor = torch.clamp(tensor, -TableLookup.MAX_EXP,
                             ((TableLookup.TABLE_SIZE - 1) / TableLookup.TABLE_SIZE * 2 - 1) * TableLookup.MAX_EXP)
        # map values to table indices
        indices = ((tensor + TableLookup.MAX_EXP) * TableLookup.TABLE_SIZE / 2 / TableLookup.MAX_EXP).long()
        # lookup
        return TableLookup.EXP_TABLE[indices]

    @classmethod
    def logsigmoid(cls, tensor):
        """For a given query tensor, calculate the logsigmoid values as:
        logsigmoid(x) = x              if x < -MAX_EXP
                      = lookup(x)      if -MAX_EXP <= x < MAX_EXP
                      = 0              if x >= MAX_EXP
        """
        # for x < -TableLookup.MAX_EXP, logsigmoid(x) is approximately x
        mask = tensor < -TableLookup.MAX_EXP
        t1 = tensor * mask.double()

        # clamp to the range covered by the table
        tensor = torch.clamp(tensor, -TableLookup.MAX_EXP,
                             ((TableLookup.TABLE_SIZE - 1) / TableLookup.TABLE_SIZE * 2 - 1) * TableLookup.MAX_EXP)
        # map values to table indices
        indices = ((tensor + TableLookup.MAX_EXP) * TableLookup.TABLE_SIZE / 2 / TableLookup.MAX_EXP).long()
        # lookup
        t2 = TableLookup.LOG_TABLE[indices]
        # wherever the mask is set, zero out t2 (that part is already covered by t1)
        t2[mask] = 0
        return t1 + t2.double()
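
For example, the lookup can be spot-checked against the built-ins like this (illustration only, input values picked arbitrarily):

x = torch.tensor([-12.0, -1.0, 0.0, 1.0, 12.0], dtype=torch.double)
print(TableLookup.sigmoid(x), torch.sigmoid(x))
print(TableLookup.logsigmoid(x), F.logsigmoid(x))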

Can someone explain what I might be doing wrong and how to fix it?