Custom autograd function with Int tensor input

Hi,

I am implementing custom autograd functions on sparse tensors. Since I use the COO format to encode sparse tensors, the input of my autograd functions is a pair of tensors: one containing the indices, of type torch(.cuda).IntTensor, and one containing the values, of type torch(.cuda).FloatTensor.
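For readers unfamiliar with the layout, here is a minimal sketch of the (indices, values) pair described above. It assumes a small 2-D sparse matrix whose shape and numbers are purely illustrative and do not come from the original post. `torch.sparse_coo_tensor` is PyTorch's own constructor for the same layout and is used here only to show how the pair maps to a dense matrix. Note that PyTorch stores COO indices as int64 (LongTensor).

```python
import torch

# A hypothetical 3x4 sparse matrix with three non-zeros, stored as a COO pair.
ind = torch.tensor([[0, 1, 2],   # row index of each non-zero
                    [1, 3, 0]])  # column index of each non-zero (int64)
val = torch.tensor([1.0, 2.0, 3.0])  # value of each non-zero (float32)

# Column k of `ind` is the coordinate of val[k]; only `val` can carry gradients,
# while `ind` is integer-valued and therefore never differentiable.
sp = torch.sparse_coo_tensor(ind, val, size=(3, 4))
dense = sp.to_dense()
print(dense)
```

Only the value tensor is a candidate for gradients. This is why the custom Functions below return `None` as the gradient for the index tensor.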

The index tensor does not have a gradient (None), but it is used to compute the gradient with respect to the value tensor. It seems that an update in PyTorch 1.3 no longer allows this kind of autograd function. Indeed, my custom functions used to work under PyTorch 1.2, but they now raise the following error:

RuntimeError: Expected isFloatingType(grads[i].type().scalarType()) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Here is a small reproducible example:

import torch



class test_discrete_grad_1(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, ind, val):
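        # max(dim=1) returns (values, indices); [0] keeps the per-row maximum value.
        # With [1] (argmax), size_table[1] can be 0 and the modulo below divides by zero.
        size_table = ind.max(dim=1)[0]
        ctx.ind = ind
        # clone so that the caller's index tensor is not modified in place
        new_ind = ind.clone()
        new_ind[1] = (new_ind[1] + 5) % (size_table[1].item())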
        new_val = val * 3
        return new_ind, new_val

    @staticmethod
    def backward(ctx, grad_ind, grad_val):
        ### grad_ind is supposed to be None
        return None, 3 * grad_val


class test_discrete_grad_2(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, ind, val):
        ctx.ind = ind 
        res = val.new_zeros(2)  # same dtype/device as val, so this also works on CUDA
        res[0] = 2 * val[(ind[1]%2)==0].sum()
        res[1] = 5 * val[(ind[1]%2)==1].sum()
        return res

    @staticmethod
    def backward(ctx, grad):
        ind = ctx.ind
        val = grad.new_zeros(ind.size(1))  # same dtype/device as the incoming gradient
        val[(ind[1]%2)==0] = 2 * grad[0]
        val[(ind[1]%2)==1] = 5 * grad[1]
        return None, val


def test():

    ind = torch.randint(low=0, high=30, size=(3, 20))
    val = torch.rand(20)
    val.requires_grad = True
    opp_1 = test_discrete_grad_1.apply
    opp_2 = test_discrete_grad_2.apply
    new_ind, new_val = opp_1(ind, val)
    res = opp_2(new_ind, new_val)
    print(res)
    out = res.sum()
    out.backward()
    print(val.grad)

test()

This script runs properly on PyTorch 1.2 and raises the error above on PyTorch 1.3 and later.

Is there a way to solve this without casting the index tensor to a float type? (In my application it would require A LOT of casting operations.)

Thank you in advance.

Samuel

Hi,

The documentation wasn’t very clear about this, but the latest version should be better. In particular, you have to mark non-differentiable outputs explicitly in your custom Functions.
In your case, you should add the following at the end of test_discrete_grad_1's forward function, before the return statement:

ctx.mark_non_differentiable(new_ind)
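A minimal sketch of the reproducer with that fix applied, for illustration only. The class names `ShiftAndScale` and `ParityWeightedSum` and the constant `NUM_COLS` are hypothetical, and the wrap-around size is fixed instead of computed from the data. It also swaps `ctx.ind = ind` for `ctx.save_for_backward(ind)`, the documented way to keep tensors for backward, and allocates buffers with `new_zeros` so they follow the input's dtype and device.

```python
import torch

NUM_COLS = 30  # hypothetical wrap-around size for the shifted column indices


class ShiftAndScale(torch.autograd.Function):
    """Shifts column indices (integer output) and scales values (float output)."""

    @staticmethod
    def forward(ctx, ind, val):
        new_ind = ind.clone()
        new_ind[1] = (new_ind[1] + 5) % NUM_COLS
        new_val = val * 3
        # The integer output has no gradient: declare it explicitly so autograd
        # never tries to build an integer-typed gradient for it.
        ctx.mark_non_differentiable(new_ind)
        return new_ind, new_val

    @staticmethod
    def backward(ctx, grad_ind, grad_val):
        # grad_ind is ignored; None is the gradient for the integer input `ind`.
        return None, 3 * grad_val


class ParityWeightedSum(torch.autograd.Function):
    """Weights values in even columns by 2 and values in odd columns by 5."""

    @staticmethod
    def forward(ctx, ind, val):
        # Keep the integer input for backward; no cast to float is needed.
        ctx.save_for_backward(ind)
        even = (ind[1] % 2) == 0
        res = val.new_zeros(2)
        res[0] = 2 * val[even].sum()
        res[1] = 5 * val[~even].sum()
        return res

    @staticmethod
    def backward(ctx, grad):
        (ind,) = ctx.saved_tensors
        even = (ind[1] % 2) == 0
        grad_val = grad.new_zeros(ind.size(1))
        grad_val[even] = 2 * grad[0]
        grad_val[~even] = 5 * grad[1]
        return None, grad_val


ind = torch.randint(low=0, high=30, size=(3, 20))
val = torch.rand(20, requires_grad=True)
new_ind, new_val = ShiftAndScale.apply(ind, val)
res = ParityWeightedSum.apply(new_ind, new_val)
res.sum().backward()
print(val.grad)  # 6.0 where the shifted column index is even, 15.0 where it is odd
```

The index tensor stays an integer tensor throughout. Only the float output of the first Function is connected to the graph.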

Alright, I will test with the latest version. Thank you for your answer!

Adding this will fix your issue for all PyTorch versions.
It’s only the latest version of the docs that explains the use of this function properly.

OK, got it. Thank you.