Is it right to implement the loss function this way?



The targets are z and pi.

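I don’t think you want binary cross-entropy for this: it doesn’t compute what your equation does. BCE scores each entry of p separately against pi as an independent yes/no target, adding (1 - pi) log(1 - p) terms, instead of computing the cross-entropy -pi^T log p over the whole distribution.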
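A small sketch of the difference, using made-up probabilities for a single sample with three actions (the tensors `p` and `pi` are hypothetical). `F.binary_cross_entropy` averages -[pi log p + (1 - pi) log(1 - p)] over every entry, which is a different number from the -pi^T log p term:

```python
import torch
import torch.nn.functional as F

# Hypothetical example: one sample, three actions.
p = torch.tensor([[0.7, 0.2, 0.1]])   # predicted move probabilities
pi = torch.tensor([[0.5, 0.3, 0.2]])  # target probabilities

# Cross-entropy term: -sum_a pi_a * log p_a per sample, averaged over the batch.
policy_ce = -(pi * p.log()).sum(dim=1).mean()

# BCE treats every entry as an independent yes/no target and averages
# -[pi*log p + (1 - pi)*log(1 - p)] over all elements.
bce = F.binary_cross_entropy(p, pi)
bce_manual = -(pi * p.log() + (1 - pi) * (1 - p).log()).mean()

print(policy_ce.item(), bce.item(), bce_manual.item())
```

The extra (1 - pi) log(1 - p) terms and the per-element averaging are why the two values disagree.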


Thank you. So how do you implement it in this case?

How about this?

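import torch.nn.functional as F

def loss(v_batch, z_batch, p_batch, pi_batch):
    # mean of (z - v)^2, minus pi^T log p summed over actions and averaged over the batch
    return F.mse_loss(v_batch.view(-1), z_batch.view(-1)) - \
        (pi_batch * p_batch.log()).sum(dim=1).mean()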
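A sketch for checking a batched version, assuming p_batch and pi_batch have shape (batch, actions). With that layout, `torch.matmul(pi_batch.t(), torch.log(p_batch))` returns an (actions x actions) matrix rather than one scalar per sample, so the policy term needs an elementwise product summed over actions. The function name `value_policy_loss`, the `policy_logits` argument, and all tensors below are hypothetical; using `F.log_softmax` on raw logits is a swapped-in design choice that avoids log(0) = -inf when a probability underflows to zero:

```python
import torch
import torch.nn.functional as F

def value_policy_loss(v, z, policy_logits, pi):
    """Hypothetical sketch: batch mean of (z - v)^2 - pi^T log p."""
    value_loss = F.mse_loss(v.view(-1), z.view(-1))
    log_p = F.log_softmax(policy_logits, dim=1)  # log p from raw logits
    policy_loss = -(pi * log_p).sum(dim=1).mean()
    return value_loss + policy_loss

# Hypothetical batch: 4 positions, 5 actions.
torch.manual_seed(0)
v = torch.tanh(torch.randn(4, 1))              # predicted values in [-1, 1]
z = torch.tensor([1.0, -1.0, 0.0, 1.0])        # game outcomes
logits = torch.randn(4, 5)
pi = torch.softmax(torch.randn(4, 5), dim=1)   # target probabilities

total = value_policy_loss(v, z, logits, pi)

# Explicit per-sample loop for comparison.
p = torch.softmax(logits, dim=1)
manual = sum((z[i] - v[i, 0]) ** 2 - (pi[i] * p[i].log()).sum() for i in range(4)) / 4

# The matmul form mixes samples and actions and is not a scalar.
wrong = torch.matmul(pi.t(), p.log())

print(total.item(), manual.item(), wrong.shape)
```

If the network already outputs probabilities, the elementwise-product version in the post above works the same way; the logits variant only matters for numerical stability.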