torch.distributions.kl.kl_divergence() raises NotImplementedError

This function expects Distribution objects, not raw tensors, so you'd need to wrap your tensors first, e.g. with Categorical(probs=p).
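Here is a minimal sketch of the fix. It assumes p and q are 1-D probability vectors over the same categories; the values are made up for illustration. Passing the raw tensors reproduces the NotImplementedError, because no KL rule is registered for a pair of plain tensors. Wrapping both in Categorical dispatches to the registered Categorical-vs-Categorical implementation. The last lines cross-check the result against the definition sum_i p_i (log p_i - log q_i).

```python
import torch
from torch.distributions import Categorical
from torch.distributions.kl import kl_divergence

# Illustrative probability vectors over 3 classes (each sums to 1).
p = torch.tensor([0.5, 0.3, 0.2])
q = torch.tensor([0.4, 0.4, 0.2])

# Raw tensors fail: kl_divergence finds no registered rule for (Tensor, Tensor).
try:
    kl_divergence(p, q)
except NotImplementedError as err:
    print("raw tensors:", err)

# Wrapping the tensors in Distribution objects makes the dispatch succeed.
kl = kl_divergence(Categorical(probs=p), Categorical(probs=q))

# Cross-check with the definition KL(p || q) = sum_i p_i * (log p_i - log q_i).
manual = (p * (p.log() - q.log())).sum()
print(kl.item(), manual.item())
```

If p and q carry a leading batch dimension (shape [batch, num_classes]), the same call returns one KL value per row.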
