this is my Custom Loss, but the backward will not work.
def OthelloLoss_(pred, label, th):
pred = torch.where(pred > th, True, False)
label = torch.where(label == 1.0, True, False)
#ex: p:110 g:011
pred_gt_or = torch.bitwise_or(pred, label) #pred+gt where 1 =>110 or 011 = 111
denominator = torch.sum(torch.where(pred_gt_or == True, 1, 0),1) #total pred+gt where 1 position =>p or q:111 => 3
pred_gt_xor = torch.bitwise_xor(pred, label) #not correct position =>110 xor 011 = 101
molecular = torch.sum(torch.where(pred_gt_xor == True, 1, 0),1) #total pred not correct num =>p xor q:101 => 2
position_loss = molecular / denominator #position loss = (not correct num) / (total postion num)
position_loss = torch.nan_to_num(position_loss, nan=0.0)
position_batch_loss = torch.div(torch.sum(position_loss), len(position_loss))
wrong_and_xor = torch.bitwise_and(pred,pred_gt_xor) #wrong postion =>p and(p xor q) = 110 and 101 = 100
wrong_num = torch.sum(torch.where(wrong_and_xor == True, 1, 0),1) #wrong num => p and(p xor q) = 100 => 1
wrong_penalty_loss = 0.5 * wrong_num
wrong_penalty_batch_loss = torch.div(torch.sum(wrong_penalty_loss), len(wrong_penalty_loss))
total_loss = position_batch_loss + wrong_penalty_batch_loss
return total_loss
every epoch loss:
Epoch: 1 | train_loss: 12.5797 | train_acc: 4.45% | val_loss: 12.583 | val_acc: 4.45%
Epoch: 2 | train_loss: 12.5797 | train_acc: 4.45% | val_loss: 12.583 | val_acc: 4.45%
Epoch: 3 | train_loss: 12.5797 | train_acc: 4.45% | val_loss: 12.583 | val_acc: 4.45%
Epoch: 4 | train_loss: 12.5797 | train_acc: 4.45% | val_loss: 12.583 | val_acc: 4.45%
…