# Bug in backprop with sparse tensors

Hello there,
I’ve run into strange behavior in torch’s backprop: it works completely differently with and without sparse tensors. Consider this code:

```
import torch
import numpy as np
import scipy.sparse

# Build a random sparse target matrix in COO format.
target = scipy.sparse.random(18000, 1000, density=0.01)
values = target.data
indices = np.vstack((target.row, target.col))

target = torch.sparse_coo_tensor(indices, values, dtype=torch.float32)
phi = torch.rand(18000, 100, requires_grad=True)
theta = torch.rand(100, 1000, requires_grad=True)

def NLLLoss_sparse(phi, theta, target):
    eps = 1e-9
    output = torch.matmul(phi, theta)
    loss = -torch.sparse.sum(target * torch.log(output + eps))
    return loss

def NLLLoss(phi, theta, target):
    eps = 1e-9
    output = torch.matmul(phi, theta)
    loss = -torch.sum(target.to_dense() * torch.log(output + eps))
    return loss

loss = NLLLoss_sparse(phi, theta, target)
loss.backward()  # raises the RuntimeError below
```

The only difference between the two loss implementations is the `.to_dense()` call. The loss that densifies the target backpropagates properly, while the sparse version gives me an error:

`RuntimeError: Sparse division requires a scalar or zero-dim dense tensor divisor (got shape [18000, 1000] for divisor)`

So why doesn’t it work? This matters, because I assume sparse operations are far faster than backpropagating through the whole matrix, which is mostly zeros.

This is expected behavior, as sparse division currently doesn’t support non-scalar divisors, but maybe we should indeed support this operation. I can file an issue for this if you’d like (or you can file one yourself if you prefer).

You’re running into this division because the derivative of log is 1 / x, so during backprop autograd ends up computing `target / (output + eps)`, i.e. dividing a sparse tensor by a non-scalar dense one.
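To see the failing operation in isolation, here is a minimal sketch (the shapes are arbitrary, not taken from your snippet):

```
import torch

s = torch.rand(4, 4).to_sparse()  # sparse COO tensor
d = torch.rand(4, 4)              # dense, non-scalar divisor

s / 2.0  # fine: dividing a sparse tensor by a scalar is supported
s / d    # RuntimeError: Sparse division requires a scalar or zero-dim dense tensor divisor
```

As a workaround you can use a custom autograd `Function` that overrides log’s backward formula to avoid this division: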

```
import torch

target_size = [18000, 1000]
values = torch.rand(100)
# Random (row, col) indices scaled into the target's shape.
indices = torch.rand(2, 100).mul_(torch.tensor(target_size).unsqueeze(-1)).to(torch.int64)

target = torch.sparse_coo_tensor(indices, values, size=target_size, dtype=torch.float32, requires_grad=True)
phi = torch.rand(18000, 100, requires_grad=True)
theta = torch.rand(100, 1000, requires_grad=True)

class Log(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.log(x)

    @staticmethod
    def backward(ctx, gO):
        x, = ctx.saved_tensors
        # we cannot do gO / x because sparse division requires
        # the divisor to be a scalar
        return gO * (1 / x)

def NLLLoss_sparse(phi, theta, target):
    eps = 1e-9
    output = torch.matmul(phi, theta)
    loss = -torch.sum(target * Log.apply(output + eps))
    return loss

out = NLLLoss_sparse(phi, theta, target)
out.backward()
```
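The key point is that the backward returns `gO * (1 / x)` instead of `gO / x`: elementwise multiplication between a sparse and a dense tensor is supported, so the sparse gradient flowing out of the sum never has to be divided by a non-scalar dense tensor.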

Thank you, that helped! I would appreciate it if you filed an issue.

Hello - just curious, was this issue ever filed on PyTorch’s GitHub issues? I’m facing a similar problem: I have a pre-defined sparse weight matrix and a dense input matrix, and I use `torch.sparse.addmm` in the forward method for the sparseWeight parameter. The forward method runs fine, but an error occurs when calling `loss.backward()`:

```
Exception has occurred: RuntimeError
Sparse division requires a scalar or zero-dim dense tensor divisor (got shape [1, 1] for divisor)
```

I was wondering what the issue might be, and whether there is a workaround?

Just for reference, this is being used with the built-in MSE loss function.
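Roughly, my setup looks like this (a simplified sketch; the actual shapes and variable names in my code differ):

```
import torch
import torch.nn.functional as F

# Pre-defined sparse weight matrix, kept as a sparse leaf with gradients.
sparse_weight = torch.rand(50, 30).to_sparse().requires_grad_(True)
bias = torch.zeros(50, 1)
x = torch.rand(30, 1)  # dense input

# torch.sparse.addmm(mat, mat1, mat2) computes beta*mat + alpha*(mat1 @ mat2),
# where mat1 is sparse; the result here is dense.
out = torch.sparse.addmm(bias, sparse_weight, x)

target = torch.rand(50, 1)
loss = F.mse_loss(out, target)  # built-in MSE loss
loss.backward()                 # the RuntimeError occurs here in my code
```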