Is it right to implement the loss function this way?



The targets are z and pi.

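I don’t think you want binary cross entropy for this: it treats each output as an independent binary probability and adds (1 - pi) * log(1 - p) terms, so it doesn’t compute what your equation does.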
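To make the difference concrete, here is a small sketch comparing `F.binary_cross_entropy` with the policy term -pi^T log p on one made-up probability vector. The numbers are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# One sample: pi is the target distribution, p is the predicted distribution (made-up values)
pi = torch.tensor([[0.7, 0.2, 0.1]])
p = torch.tensor([[0.5, 0.3, 0.2]])

# BCE averages -[pi*log(p) + (1 - pi)*log(1 - p)] over every element
bce = F.binary_cross_entropy(p, pi)
manual_bce = -(pi * torch.log(p) + (1 - pi) * torch.log(1 - p)).mean()

# Cross-entropy between distributions: -sum_a pi_a * log(p_a), averaged over the batch
ce = -(pi * torch.log(p)).sum(dim=1).mean()

print(bce.item(), manual_bce.item(), ce.item())
```

The extra (1 - pi) * log(1 - p) terms and the averaging over actions are why BCE gives a different value from the sum in the equation.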

Thank you. So how would you implement it in this case?

How about this?

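import torch
import torch.nn.functional as F

def loss(v_batch, z_batch, p_batch, pi_batch):
    # (z - v)^2 averaged over the batch, plus -pi^T log p per sample, averaged over the batch
    value_loss = F.mse_loss(v_batch, z_batch)
    policy_loss = -(pi_batch * torch.log(p_batch)).sum(dim=1).mean()
    return value_loss + policy_loss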

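With batched inputs of shape (B, A), a matrix product such as `pi_batch.t() @ torch.log(p_batch)` gives an (A, A) matrix rather than one dot product per sample, so the policy term needs an elementwise product summed over the action dimension. Below is a minimal sketch that checks a batched loss against an explicit per-sample loop, assuming `v` and `z` have shape (B,) and `p`, `pi` have shape (B, A). The `loss_from_logits` variant is a swapped-in alternative, not part of the original question: it applies `F.log_softmax` to raw network outputs to avoid `log(0)` when a probability underflows.

```python
import torch
import torch.nn.functional as F

def loss(v_batch, z_batch, p_batch, pi_batch):
    # Value term: mean of (z - v)^2; policy term: mean over samples of -sum_a pi_a * log(p_a)
    value_loss = F.mse_loss(v_batch, z_batch)
    policy_loss = -(pi_batch * torch.log(p_batch)).sum(dim=1).mean()
    return value_loss + policy_loss

def loss_from_logits(v_batch, z_batch, logits_batch, pi_batch):
    # Same policy term, but computed from raw logits with log_softmax for numerical stability
    value_loss = F.mse_loss(v_batch, z_batch)
    policy_loss = -(pi_batch * F.log_softmax(logits_batch, dim=1)).sum(dim=1).mean()
    return value_loss + policy_loss

torch.manual_seed(0)
B, A = 4, 5
v = torch.tanh(torch.randn(B))                       # predicted values in (-1, 1)
z = torch.randint(0, 2, (B,)).float() * 2 - 1        # game outcomes in {-1, +1}
logits = torch.randn(B, A)
p = torch.softmax(logits, dim=1)                     # predicted move probabilities
pi = torch.softmax(torch.randn(B, A), dim=1)         # target move probabilities

# Per-sample reference: (z - v)^2 - pi^T log p, then averaged over the batch
ref = torch.stack(
    [(z[i] - v[i]) ** 2 - torch.dot(pi[i], torch.log(p[i])) for i in range(B)]
).mean()

out = loss(v, z, p, pi)
out_logits = loss_from_logits(v, z, logits, pi)
print(out.item(), ref.item(), out_logits.item())
```

If your value head outputs shape (B, 1), squeeze it (or reshape `z` to match) before calling `F.mse_loss`, since mismatched shapes would broadcast silently.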