Sparse softmax as functional?

I have a torch tensor of shape (batch_size, N). I want to apply functional softmax with dim 1 to this tensor, but I also want it to ignore zeros in the tensor and only apply it to non-zero values (the non-zeros in the tensor are positive numbers). I think what I am looking for is the sparse softmax.

I came across this code on GitHub, but it seems to use nn.Module rather than a functional interface.

How can I use that sparse softmax as a function (that is, not as a layer)?
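The behaviour described in the question (a softmax over the non-zero entries of each row, with zero entries left at probability 0) can be sketched as a masked softmax. This is a minimal sketch, assuming exact zeros mark the entries to skip; the name masked_softmax is hypothetical. Note that this is a different transformation from sparsemax (the function in the linked code): sparsemax takes a dense input and may return exact zeros, rather than skipping zero inputs.

```python
import torch


def masked_softmax(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Softmax over the non-zero entries of x along dim (hypothetical helper).

    Zero entries are treated as absent and receive probability 0.
    """
    mask = x != 0
    # exp(-inf) == 0, so the filled positions drop out of the normalization.
    filled = x.masked_fill(~mask, float("-inf"))
    probs = torch.softmax(filled, dim=dim)
    # An all-zero row becomes all -inf, where softmax returns NaN; force those to 0.
    return probs.masked_fill(~mask, 0.0)


x = torch.tensor([[1.0, 0.0, 2.0],
                  [0.0, 0.0, 0.0]])
print(masked_softmax(x, dim=1))  # row 0 -> about [0.2689, 0, 0.7311]; row 1 stays all zeros
```

Because it only uses built-in tensor operations, it can be called directly like torch.nn.functional.softmax, with no module involved.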

Looking at their code, something like this might do the trick:

def sparsemax(input, dim):
    """
    Args:
        input (torch.Tensor): Input tensor. First dimension should be the batch size
        dim (int): Dimension along which to apply sparsemax
    Returns:
        torch.Tensor: [batch_size x number_of_logits] Output tensor
    """
    # Sparsemax currently only handles 2-dim tensors,
    # so we reshape to a convenient shape and reshape back after sparsemax
    input = input.transpose(0, dim)
    original_size = input.size()
    input = input.reshape(input.size(0), -1)
    input = input.transpose(0, 1)
    dim = 1
    ...
    ...
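Because the function above elides its core computation, here is a self-contained sketch of a functional sparsemax written from the standard sort-and-threshold algorithm (Martins & Astudillo, 2016). It is not a copy of the linked repository's code. It mirrors the signature above, but handles any dim by moving that axis last instead of reshaping to 2-D. For each slice z sorted in descending order, the support size is the largest k with 1 + k * z_(k) > sum_{j<=k} z_(j), the threshold is tau = (sum of the k largest entries - 1) / k, and the output is max(z - tau, 0).

```python
import torch


def sparsemax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Sparsemax along dim: Euclidean projection onto the probability simplex."""
    # Work on the last axis so every slice along dim is a row.
    z = input.transpose(dim, -1)
    n = z.size(-1)

    # Sort each row in descending order and take running sums.
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    cumsum = z_sorted.cumsum(dim=-1)

    # k = 1..n; index k is in the support if 1 + k * z_(k) > sum_{j<=k} z_(j).
    k = torch.arange(1, n + 1, device=z.device, dtype=z.dtype)
    in_support = (1 + k * z_sorted) > cumsum

    # The support is a prefix of the sorted row, so its size is the count of True.
    k_z = in_support.sum(dim=-1, keepdim=True)

    # Threshold tau = (sum of the k_z largest entries - 1) / k_z.
    tau = (cumsum.gather(-1, k_z - 1) - 1) / k_z.to(z.dtype)

    # Shift by tau, clip at zero, then move dim back into place.
    return torch.clamp(z - tau, min=0).transpose(dim, -1)


x = torch.tensor([[1.0, 2.0, 3.0],
                  [0.5, 0.6, 0.1]])
print(sparsemax(x, dim=1))  # row 0 -> [0, 0, 1]; every row sums to 1
```

Since it is built only from differentiable tensor operations (sort, cumsum, gather, clamp), autograd can back-propagate through it without a hand-written backward().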

So I can use the whole forward() function and ignore backward(), right?

If you ignore it, I think your model will compute the gradient incorrectly during back-propagation.
Hence the role of backward() (and thus the use of nn.Module).
Let me check with an example.
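One way to run such a check, assuming the forward pass is built only from differentiable tensor operations: autograd records those operations and derives the backward pass by itself, and torch.autograd.gradcheck compares the resulting gradient against finite differences. The hand-written my_softmax below is a hypothetical stand-in for any forward-only function. (Autograd only calls a custom backward() defined on a torch.autograd.Function; a backward method on an nn.Module is not invoked by autograd.)

```python
import torch


def my_softmax(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Softmax from plain tensor ops, with no custom backward (hypothetical helper)."""
    # Subtract the row max for numerical stability; softmax is shift-invariant,
    # so detaching the shift does not change the gradient.
    shifted = x - x.amax(dim=dim, keepdim=True).detach()
    e = shifted.exp()
    return e / e.sum(dim=dim, keepdim=True)


# gradcheck expects double-precision inputs that require grad.
x = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(lambda t: my_softmax(t, dim=1), (x,))
print(ok)  # True: autograd's gradient matches finite differences

# The gradient also matches the one from PyTorch's built-in softmax.
w = torch.randn(3, 5, dtype=torch.double)
(g_mine,) = torch.autograd.grad((my_softmax(x) * w).sum(), x)
(g_ref,) = torch.autograd.grad((torch.softmax(x, dim=1) * w).sum(), x)
print(torch.allclose(g_mine, g_ref))  # True
```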

Why don’t you want to use the version with nn.Module?

Hi Aerys!

I didn’t look at the GitHub code you linked to, but, in general, if you
have a (properly written) torch.nn.Module or
torch.autograd.Function, you can instantiate it “on the fly” and
call it, rather than using it as a “layer” in a model.

Thus, for example:

probs = torch.nn.Softmax (dim = 1) (logits)

is the practical equivalent of:

probs = torch.nn.functional.softmax (logits, dim = 1)
# or
probs = torch.softmax (logits, dim = 1)
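For a torch.autograd.Function, the "on the fly" call is the class's apply method (in current PyTorch, forward and backward are static methods). Below is a minimal sketch, not the linked repository's code: a hypothetical SparsemaxFunction over the last dimension whose backward uses the closed-form sparsemax gradient, which is (g_i minus the mean of g over the support) on the support and 0 elsewhere. It is used purely as a function, with no layer.

```python
import torch


class SparsemaxFunction(torch.autograd.Function):
    """Sparsemax over the last dimension with a hand-written backward (hypothetical)."""

    @staticmethod
    def forward(ctx, z):
        n = z.size(-1)
        z_sorted, _ = torch.sort(z, dim=-1, descending=True)
        cumsum = z_sorted.cumsum(dim=-1)
        k = torch.arange(1, n + 1, device=z.device, dtype=z.dtype)
        # Support size: number of sorted entries with 1 + k * z_(k) > running sum.
        k_z = ((1 + k * z_sorted) > cumsum).sum(dim=-1, keepdim=True)
        tau = (cumsum.gather(-1, k_z - 1) - 1) / k_z.to(z.dtype)
        out = torch.clamp(z - tau, min=0)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        support = (out > 0).to(grad_output.dtype)
        # Mean of the incoming gradient over the support of each row.
        mean_g = (grad_output * support).sum(dim=-1, keepdim=True) / support.sum(
            dim=-1, keepdim=True
        )
        # Gradient is (g - mean) on the support and 0 elsewhere.
        return support * (grad_output - mean_g)


x = torch.tensor([[0.1, 1.2, 0.9, -0.5],
                  [2.0, 0.3, 0.4, 0.35]], dtype=torch.double, requires_grad=True)
y = SparsemaxFunction.apply(x)  # called like a function, no module instance
print(y)  # row 0 -> about [0, 0.65, 0.35, 0]; row 1 -> [1, 0, 0, 0]
print(torch.autograd.gradcheck(SparsemaxFunction.apply, (x,)))  # True
```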

Best.

K. Frank

Hi KFrank, thanks for your help. Also, as pascal_notsawo said, I only need it to apply to a tensor separately, and I do not need its gradients. So in that case, can I just wrap it with no_grad and it will be fine, right?

Hi Aerys!

Roughly speaking, whether computations are tracked for potential future
gradient computation depends on the tensor being passed through the
function, rather than on the function itself.

In short, I wouldn’t worry about it. Even if you pass a tensor with
requires_grad = True through your function, your function’s
backward() will only be called if you call backward() on something
that depends on the result of calling your function.
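A small way to see this, using a hypothetical CountingIdentity function that only records how often its backward() runs: passing a tensor that requires grad through it records a graph node but does not call backward(); only calling backward() on something downstream does.

```python
import torch


class CountingIdentity(torch.autograd.Function):
    """Hypothetical identity op that counts how many times backward() runs."""

    backward_calls = 0

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        CountingIdentity.backward_calls += 1
        return grad_output


x = torch.randn(4, requires_grad=True)
y = CountingIdentity.apply(x)
print(CountingIdentity.backward_calls)  # 0: graph recorded, backward not run

y.sum().backward()
print(CountingIdentity.backward_calls)  # 1: backward ran once
```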

(If you do call your function on a tensor with requires_grad = True
and you know that you won’t be needing the associated gradient,
you can wrap your function call in a with torch.no_grad(): block
to save on the overhead of building that branch of the computation
graph. But I would wrap the function call rather than the innards of
the function itself.)
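For instance, with torch.softmax standing in for your function (any function built from tensor operations behaves the same way), the call made inside torch.no_grad() returns a tensor with no grad_fn, so no branch of the graph is built for it, while the values are unchanged.

```python
import torch

x = torch.randn(2, 3, requires_grad=True)

# Outside no_grad: the result is tracked for a possible backward().
y = torch.softmax(x, dim=1)
print(y.requires_grad, y.grad_fn is not None)  # True True

# Wrapping the call (not the function's internals) skips graph construction.
with torch.no_grad():
    y_untracked = torch.softmax(x, dim=1)
print(y_untracked.requires_grad, y_untracked.grad_fn)  # False None
```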

Best.

K. Frank