I would like to add a sparsity regularization term on the encodings of my VAE. To do this, I set an activation tensor to 1 at every index where the encoding is nonzero (and 0 elsewhere), then take its mean over the batch dimension. This gives, for each unit, the fraction of samples in which it is active. I then compute a pointwise Kullback-Leibler divergence between this activation rate and the desired sparsity probability.
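For reference, here is a minimal sketch of the per-unit penalty in its classic sparse-autoencoder form. It assumes each unit's activation rate is treated as the parameter of a Bernoulli distribution and compared with a target Bernoulli(`rho`). The name `bernoulli_kl` and the `eps` clamp are illustrative choices, not from the question. Note that `torch.nn.functional.kl_div` (with its default `log_target=False`) computes only `target * (log(target) - input)` per element, which corresponds to the first of the two Bernoulli terms below.

```python
import torch

def bernoulli_kl(rho: float, rho_hat: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """KL(Bernoulli(rho) || Bernoulli(rho_hat)) per unit, summed over units."""
    # Clamp so a unit that is never (or always) active does not produce log(0).
    rho_hat = rho_hat.clamp(eps, 1 - eps)
    kl = rho * torch.log(rho / rho_hat) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat))
    return kl.sum()

# Hypothetical activation rates for 3 units (fraction of the batch in which each unit is nonzero).
rates = torch.tensor([0.05, 0.5, 0.0])
print(bernoulli_kl(0.05, rates))
```

The penalty is zero when a unit's rate equals `rho` and grows as the rate moves away from it in either direction.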
However, the output of the function has no `grad_fn` (it is `None`), so I suspect the gradient is not propagated back to the encoding. Here is a minimal working example with the function I use for the regularization:
```python
import torch
import torch.nn.functional as F

def sparsity_regularizer(enc, sparsity=.05):
    activations = torch.zeros(size=enc.size())
    # 1 wherever the encoding is nonzero, 0 elsewhere
    activations[torch.nonzero(enc, as_tuple=True)] = 1
    # take mean along batch dimension (fraction of samples in which each unit is active)
    mean = torch.mean(activations, dim=0)
    reg = -F.kl_div(mean.log(), sparsity*torch.ones(size=mean.size()), reduction="sum")
    return reg

enc = torch.Tensor([0, 2])
enc.requires_grad = True
out = sparsity_regularizer(enc)
print(out)
```
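The point where the graph is lost can be traced with a few checks on the same construction. This is a diagnostic sketch: the `tanh(|enc|)` stand-in at the end is a hypothetical smooth proxy for the indicator, chosen only to show that an operation applied directly to `enc` keeps a `grad_fn`. It is not a claim about the right regularizer.

```python
import torch

enc = torch.tensor([0.0, 2.0], requires_grad=True)

# Construction from the question: a fresh tensor filled with the constant 1
# at the (integer) indices returned by torch.nonzero. Nothing links it to enc.
activations = torch.zeros(enc.size())
activations[torch.nonzero(enc, as_tuple=True)] = 1
print(activations.requires_grad, activations.grad_fn)  # False None

# A one-line indicator is detached as well: comparisons return a bool tensor
# with no gradient, and a step function has zero derivative almost everywhere.
indicator = (enc != 0).float()
print(indicator.requires_grad)  # False

# Hypothetical smooth stand-in: 0 where enc == 0, close to 1 for large |enc|.
# Because it is computed from enc, it stays in the autograd graph.
soft = torch.tanh(enc.abs())
print(soft.grad_fn is not None)  # True
soft.sum().backward()
print(enc.grad is not None)  # True
```

In short, both the index assignment into `torch.zeros(...)` and the hard nonzero test cut the connection to `enc`, so anything computed from `activations` (including `mean` and the KL term) has `grad_fn` equal to `None`.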
What do I need to change so that the gradient flows back through the regularizer to the encoding?