How to build a complicated custom loss function in PyTorch

My output is prediction with size (batch_size, 15), and I want my loss to be:
a * BCE(prediction[:, 14], target[:, 14]) + BCE(prediction[:, 0:14], target[:, 0:14]) * (prediction[:, 14] == target[:, 14])

To be more specific, I want to use the highest bit as an indicator, so the loss has two parts. The first part is the loss of the indicator bit, i.e. BCE(prediction[:, 14], target[:, 14]). The second part is the loss of the lower 14 bits, counted only when the indicator bit is predicted correctly: BCE(prediction[:, 0:14], target[:, 0:14]) * (prediction[:, 14] == target[:, 14]).
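Written out per sample, with N the batch size, p = prediction and t = target, one reading of this loss is the following. The 0.5 threshold that decides whether the indicator bit is "right", and the choice to average the second term over only the kept samples, are assumptions rather than something stated above.

$$
\mathcal{L} = a\,\frac{1}{N}\sum_{i=1}^{N} \ell(p_{i,14}, t_{i,14})
 + \frac{1}{\max\bigl(1, \sum_{i} m_i\bigr)}\sum_{i=1}^{N} m_i\,\frac{1}{14}\sum_{j=0}^{13} \ell(p_{i,j}, t_{i,j})
$$

$$
\ell(p, t) = -\bigl[t\log p + (1-t)\log(1-p)\bigr],\qquad
m_i = \mathbb{1}\bigl[\,\mathbb{1}[p_{i,14} > 0.5] = t_{i,14}\,\bigr]
$$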

How can I implement this in PyTorch?

Well, exactly as you wrote it:

import torch.nn as nn
bce_loss = nn.BCELoss()
loss = (a * bce_loss(prediction[:, 14], target[:, 14])
        + bce_loss(prediction[:, 0:14], target[:, 0:14]) * (prediction[:, 14] == target[:, 14]))
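Before relying on the one-liner, it can help to run it on dummy data and look at what comes out. In this sketch the random tensors, the seed, batch_size = 3 and a = 2.0 are assumptions; prediction goes through torch.sigmoid because nn.BCELoss expects probabilities.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
a = 2.0          # assumed weight for the indicator term
batch_size = 3
# nn.BCELoss expects probabilities, so squash random logits with a sigmoid.
prediction = torch.sigmoid(torch.randn(batch_size, 15))
target = torch.randint(0, 2, (batch_size, 15)).float()

bce_loss = nn.BCELoss()
# Compares raw probabilities with 0/1 labels, one entry per sample.
mask = prediction[:, 14] == target[:, 14]
loss = (a * bce_loss(prediction[:, 14], target[:, 14])
        + bce_loss(prediction[:, 0:14], target[:, 0:14]) * mask)

print(mask)        # a probability strictly between 0 and 1 never equals a 0/1 label
print(loss.shape)  # torch.Size([3]): one value per sample, not a scalar
```

Two things show up: the scalar lower-bit loss is broadcast against a per-sample mask, so the result is a vector of length batch_size (and loss.backward() needs a scalar unless you pass an explicit gradient), and comparing probabilities to labels with == essentially never selects anything.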

Thanks for your reply. I will try it!

Actually, I don't think this will work the way I want. I want to compute the loss of the lower 14 bits only for the samples whose indicator bit is predicted correctly, deciding this element-wise for each sample rather than once for the whole batch.

For example, if target[:, 14] = [0 1 1] and prediction[:, 14] = [0 1 0], I would like the second part of the loss to include the first two samples but not the third one, since the third sample's predicted indicator bit differs from its target.
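One way to get this per-sample behaviour is to ask BCE for unreduced, element-wise losses (reduction="none") and multiply them by a 0/1 mask built from the indicator bit. This is a minimal sketch under a few assumptions: prediction holds probabilities (e.g. sigmoid outputs), the indicator counts as correct when prediction[:, 14] > 0.5 matches target[:, 14], and the masked term is averaged over the kept samples only. The function name indicator_masked_bce and its threshold parameter are hypothetical.

```python
import torch
import torch.nn.functional as F


def indicator_masked_bce(prediction: torch.Tensor, target: torch.Tensor,
                         a: float = 1.0, threshold: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: indicator-bit BCE plus per-sample masked lower-bit BCE.

    prediction: probabilities in (0, 1), shape (batch_size, 15).
    target:     0/1 floats, same shape.
    """
    # Part 1: BCE of the indicator bit (column 14), averaged over the batch.
    indicator_loss = F.binary_cross_entropy(prediction[:, 14], target[:, 14])

    # Element-wise BCE of the lower 14 bits, shape (batch_size, 14).
    lower_bce = F.binary_cross_entropy(prediction[:, :14], target[:, :14],
                                       reduction="none")

    # Assumption: the indicator is "right" when the thresholded prediction matches
    # the target. Comparison ops carry no gradient, so the mask is a constant.
    correct = ((prediction[:, 14] > threshold).float() == target[:, 14]).float()

    # Mean over the 14 bits per sample, zeroed for samples with a wrong indicator.
    per_sample = lower_bce.mean(dim=1) * correct
    # Average over kept samples; clamp avoids dividing by zero if none are kept.
    lower_loss = per_sample.sum() / correct.sum().clamp(min=1.0)

    return a * indicator_loss + lower_loss
```

With target[:, 14] = [0, 1, 1] and prediction[:, 14] = [0.1, 0.9, 0.2], the mask is [1, 1, 0], so the lower 14 bits of the third sample contribute nothing to the loss and receive zero gradient. If you would rather average over the whole batch, replace the last division with per_sample.mean().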