paper
implement
The targets are z and pi.
I don’t think you want binary cross entropy for this; it doesn’t compute what your equation does. See http://pytorch.org/docs/master/nn.html#torch.nn.BCELoss
Thank you. So how do you implement it in this case?
How about this?
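import torch
import torch.nn.functional as F

def loss(v_batch, z_batch, p_batch, pi_batch):
    # Value MSE minus sum(pi * log(p)) per sample, averaged over the batch.
    # Assumes p_batch and pi_batch have shape (batch, actions).
    return F.mse_loss(v_batch, z_batch) - \
        (pi_batch * torch.log(p_batch)).sum(dim=1).mean()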
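One thing to watch with batches: if p_batch and pi_batch are both (batch, actions), a matmul with the transpose gives a matrix rather than one scalar per sample, so an elementwise product summed over the action dimension is probably what you want. Here is a minimal sketch of that, assuming the network outputs raw logits rather than probabilities (my assumption, not something stated above), so that `log_softmax` can be used to avoid `log(0)`. The name `policy_value_loss`, the shapes, and the random test data are only for illustration.

```python
import torch
import torch.nn.functional as F


def policy_value_loss(v_batch, z_batch, logits_batch, pi_batch):
    """Sketch: value MSE plus policy cross-entropy, averaged over the batch.

    Assumed shapes (hypothetical, not given in the thread):
      v_batch, z_batch:        (B,)
      logits_batch, pi_batch:  (B, A), each row of pi_batch sums to 1
    """
    # (z - v)^2 averaged over the batch
    value_loss = F.mse_loss(v_batch, z_batch)
    # log p computed from logits in a numerically stable way
    log_p = F.log_softmax(logits_batch, dim=1)
    # -pi^T log p per sample, then averaged over the batch
    policy_loss = -(pi_batch * log_p).sum(dim=1).mean()
    return value_loss + policy_loss


# Small random batch just to exercise the function.
torch.manual_seed(0)
B, A = 4, 5
v = torch.tanh(torch.randn(B))                   # predicted values in (-1, 1)
z = torch.tensor([1.0, -1.0, 1.0, -1.0])         # game outcomes
logits = torch.randn(B, A, requires_grad=True)   # raw policy outputs
pi = torch.softmax(torch.randn(B, A), dim=1)     # target search probabilities

total = policy_value_loss(v, z, logits, pi)
total.backward()
print(total.item())
```

If your network already applies a softmax, you can pass `torch.log(p_batch)` in place of `log_p`, but clamping or adding a small epsilon may be needed to keep the log finite.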