Gradient not passing through sparsity regularization

I would like to add a sparsity regularizer to the encodings of my VAE. To do this, I build an activation tensor that is 1 wherever the encoding is nonzero and 0 elsewhere, then take the mean over the batch dimension to get the activation frequency of each unit. I then take the pointwise Kullback-Leibler divergence between these frequencies and the desired sparsity probability.
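Written out, the quantity described above could look roughly like the following. The symbols are introduced here for illustration and are not from my code: z_{ij} is entry j of the encoding for sample i, N is the batch size, \rho is the target sparsity, and \hat{\rho}_j is the fraction of samples in which unit j is active. The overall sign is left aside.

```latex
\hat{\rho}_j = \frac{1}{N} \sum_{i=1}^{N} \mathbf{1}[z_{ij} \neq 0],
\qquad
R = \sum_{j} \rho \, \log \frac{\rho}{\hat{\rho}_j}
```

This is only the single pointwise term; a full Bernoulli KL divergence would also include a (1 - \rho) term. Note that the indicator \mathbf{1}[z_{ij} \neq 0] is piecewise constant in z_{ij}, so its derivative is zero wherever it is defined.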

However, the grad_fn of the function's output is None, so I assume no gradient is propagated back through it. Here is a minimal working example with the function I use for the regularization.

import torch
import torch.nn.functional as F

def sparsity_regularizer(enc, sparsity=.05):
    activations = torch.zeros(size=enc.size())
    # mark entries where the encoding is nonzero as active
    activations[torch.nonzero(enc, as_tuple=True)] = 1
    # take mean along batch dimension
    mean = torch.mean(activations, dim=0)
    reg = -F.kl_div(mean.log(), sparsity*torch.ones(size=mean.size()), reduction="sum")
    return reg

enc = torch.tensor([0.0, 2.0], requires_grad=True)

out = sparsity_regularizer(enc)
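
To narrow down where the graph is cut, a small check along these lines may help. It is only a sketch under assumptions: it rebuilds the indicator step in isolation, and the tanh-based `soft` value is a hypothetical smooth stand-in used for contrast, not part of the function above.

```python
import torch

enc = torch.tensor([0.0, 2.0], requires_grad=True)

# Same construction as the indicator step: a fresh tensor that is
# filled in place, using index positions derived from enc.
activations = torch.zeros_like(enc)
activations[torch.nonzero(enc, as_tuple=True)] = 1

# The indices are integer tensors, so autograd records no link to enc.
print(activations.requires_grad)
print(activations.grad_fn)

# Hypothetical contrast (an assumption, not the original method):
# a value computed from enc with differentiable ops stays in the graph.
soft = torch.tanh(enc.abs()).mean()
print(soft.requires_grad)
print(soft.grad_fn)
```

If the first two prints show that `activations` is not tracked while `soft` is, the break happens at the indicator step rather than in `F.kl_div`.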


What do I need to change?