Getting NaN after first iteration with custom loss

You can print these tensors directly in the forward pass to inspect their values while debugging.
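As a rough sketch of what that could look like, assuming the model is written in PyTorch: the helper `report`, the model `TinyNet`, and the loss `custom_loss` below are all hypothetical names made up for illustration, not code from the original question. The idea is simply to print each intermediate tensor's range and check whether it is still finite, so you can see which step first produces a NaN.

```python
import torch
import torch.nn as nn


def report(name, t):
    """Print shape, min, max, and a finiteness flag; return True if all values are finite."""
    finite = bool(torch.isfinite(t).all())
    print(f"{name}: shape={tuple(t.shape)} "
          f"min={t.min().item():.4g} max={t.max().item():.4g} finite={finite}")
    return finite


class TinyNet(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        report("input", x)        # check what goes in
        out = self.fc(x)
        report("fc output", out)  # check what comes out of each layer
        return out


def custom_loss(pred, target):  # hypothetical loss, for illustration only
    # log() of a non-positive value yields NaN or -inf, a common NaN source
    log_pred = torch.log(pred)
    report("log(pred)", log_pred)
    loss = ((log_pred - target) ** 2).mean()
    report("loss", loss)
    return loss


torch.manual_seed(0)
model = TinyNet()
x = torch.randn(3, 4)
pred = model(x)
loss = custom_loss(pred, torch.zeros(3, 2))
```

Once the first tensor with `finite=False` shows up in the output, the operation just before it is the place to look (for example, clamping the argument of `log` away from zero). Remove the prints again once the problem is found, since calling `.item()` forces a sync on GPU and slows training.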