My output is a prediction of size (batch_size, 15), and I want my loss to look like:
a * BCE(prediction[:, 14], target[:, 14]) + BCE(prediction[:, 0:14], target[:, 0:14]) * (prediction[:, 14] == target[:, 14])
To be more specific, I want to use the highest bit as an indicator, so the loss has two parts. The first part is the loss on the indicator bit, i.e. BCE(prediction[:, 14], target[:, 14]). The second part is the loss on the lower 14 bits, counted only when the indicator bit is predicted correctly, i.e. BCE(prediction[:, 0:14], target[:, 0:14]) * (prediction[:, 14] == target[:, 14]).
However, I do not think this will work as written, because I want the loss on the lower 14 bits to be computed only for samples whose predicted indicator bit equals the target indicator bit, with that check applied element-wise, one sample at a time.
For example, if target[:, 14] = [0 1 1] and prediction[:, 14] = [0 1 0], I would like to compute the second part of the loss for the first two samples only, excluding the third one because its indicator prediction differs from its target.
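To make the intended behaviour concrete, here is a minimal sketch of a per-sample masked loss in PyTorch. It rests on several assumptions that are not stated above: that prediction holds probabilities in [0, 1] (for example after a sigmoid), that the indicator prediction is turned into a 0/1 decision with a 0.5 threshold, and that the second part is averaged over the samples that pass the mask. The function name indicator_masked_bce and the threshold parameter are hypothetical.

```python
import torch
import torch.nn.functional as F


def indicator_masked_bce(prediction, target, a=1.0, threshold=0.5):
    """Sketch of the two-part loss.

    prediction: (batch_size, 15) probabilities in [0, 1] (assumption).
    target:     (batch_size, 15) float tensor of 0s and 1s.
    Returns the scalar loss and the per-sample mask.
    """
    # Part 1: BCE on the indicator bit (column 14), averaged over the batch.
    indicator_loss = F.binary_cross_entropy(prediction[:, 14], target[:, 14])

    # Per-sample mask: 1.0 where the thresholded indicator prediction equals
    # the target indicator. The comparison carries no gradient, so the mask
    # acts as a constant weight during backpropagation.
    predicted_bit = (prediction[:, 14] > threshold).float()
    mask = (predicted_bit == target[:, 14]).float()  # shape: (batch_size,)

    # Part 2: element-wise BCE on the lower 14 bits, shape (batch_size, 14).
    lower = F.binary_cross_entropy(
        prediction[:, :14], target[:, :14], reduction="none"
    )
    # Average over the 14 bits, then zero out samples with a wrong indicator.
    per_sample = lower.mean(dim=1) * mask
    # Average over the kept samples; clamp avoids dividing by zero when
    # every indicator in the batch is wrong.
    lower_loss = per_sample.sum() / mask.sum().clamp(min=1.0)

    return a * indicator_loss + lower_loss, mask
```

With target[:, 14] = [0 1 1] and indicator probabilities such as [0.1, 0.9, 0.2], the mask would be [1, 1, 0], so the lower 14 bits of the third sample contribute nothing to the second part. The key point is reduction="none": with the default mean reduction BCE returns a single scalar, and multiplying that scalar by the comparison cannot select individual samples.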