Hello PyTorchers! I have been trying to implement OHEM (online hard example mining) and have reached what I believe is a reasonable version. Given the lack of examples out there, I decided to share this implementation in the hope of getting an opinion and a fresh pair of eyes!
My question comes in two parts:
- Is my solution correct? If not, please advise. It would also be great if anyone could point me toward a more optimized solution in terms of generalizability and computational cost (I sketch one idea at the end of this post).
- Since my implementation only mines the hard negative examples, I was wondering whether the same can be done for the positive examples. In other words, does taking the top-k hardest positive samples to compute the loss for the positive class make sense? (See the second sketch at the end of this post.)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def hard_mining(neg_output, neg_labels, ratio):
    ## keep only the hardest negatives, i.e. the negatives with the
    ## highest predicted probability of belonging to the positive class
    num_inst = neg_output.size(0)
    num_hard = max(int(ratio * num_inst), 1)
    _, idcs = torch.topk(neg_output, min(num_hard, num_inst))
    neg_output = torch.index_select(neg_output, 0, idcs)
    neg_labels = torch.index_select(neg_labels, 0, idcs)
    return neg_output, neg_labels


class BCE_OHEM(nn.Module):
    def __init__(self, ratio):
        super(BCE_OHEM, self).__init__()
        self.ratio = ratio

    def forward(self, inputs, targets):
        ## split the batch into positives and negatives (both 1-D);
        ## assumes at least one sample of each class in the batch
        pos_idcs = targets[:, 0] >= 0.5
        pos_output = inputs[:, 0][pos_idcs]
        pos_labels = targets[:, 0][pos_idcs]
        neg_idcs = targets[:, 0] < 0.5
        neg_output = inputs[:, 0][neg_idcs]
        neg_labels = targets[:, 0][neg_idcs]
        ## locate the top-k hard negatives
        ## (for detection, num_hard is often scaled by the batch size; not needed here)
        neg_output, neg_labels = hard_mining(neg_output, neg_labels, self.ratio)
        ## compute the loss for the negative class over the top-k selection
        neg_BCE_loss = F.binary_cross_entropy(neg_output, neg_labels, reduction='none')
        ## compute the loss over the (full) positive class
        pos_BCE_loss = F.binary_cross_entropy(pos_output, pos_labels, reduction='none')
        ## overall loss: equally weighted average of the per-class mean losses
        classify_loss = 0.5 * pos_BCE_loss.mean() + 0.5 * neg_BCE_loss.mean()
        return classify_loss
if __name__ == '__main__':
    logits = torch.tensor(np.array([[-0.3], [0.8], [-0.4], [0.99]]), requires_grad=True)
    predictions = torch.sigmoid(logits)  ## [[0.4256], [0.6900], [0.4013], [0.7291]]
    targets = torch.tensor(np.array([[0.], [0.], [1.], [1.]]), requires_grad=False)
    print(F.binary_cross_entropy(predictions, targets))  ## tensor(0.7386, dtype=torch.float64, grad_fn=<BinaryCrossEntropyBackward>)
    ohem_BCE = BCE_OHEM(ratio=1.0)
    print(ohem_BCE(predictions, targets))  ## tensor(0.7386, dtype=torch.float64)
    ohem_BCE = BCE_OHEM(ratio=0.5)
    print(ohem_BCE(predictions, targets))  ## tensor(0.8928, dtype=torch.float64)
    ohem_BCE(predictions, targets).backward()
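To make the first question more concrete: one idea I have been toying with for generalizability and computational cost is to mine on the per-sample loss instead of the raw probability. For a negative label, BCE = -log(1 - p) is monotonically increasing in p, so taking the top-k losses selects exactly the same negatives as taking the top-k probabilities, while the same selection rule would also cover positives (their loss -log(p) is largest for the lowest p). This is only a rough sketch, and the class name BCE_OHEM_on_loss is just something I made up:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BCE_OHEM_on_loss(nn.Module):
    def __init__(self, ratio):
        super(BCE_OHEM_on_loss, self).__init__()
        self.ratio = ratio

    def forward(self, inputs, targets):
        ## per-sample BCE over the whole batch in a single call
        per_sample = F.binary_cross_entropy(inputs[:, 0], targets[:, 0], reduction='none')
        pos_mask = targets[:, 0] >= 0.5
        pos_loss = per_sample[pos_mask]
        neg_loss = per_sample[~pos_mask]
        ## keep only the k largest-loss (hardest) negatives
        k = max(int(self.ratio * neg_loss.size(0)), 1)
        neg_loss, _ = torch.topk(neg_loss, min(k, neg_loss.size(0)))
        ## (applying the same topk line to pos_loss would mine hard positives too)
        return 0.5 * pos_loss.mean() + 0.5 * neg_loss.mean()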
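And for the second question, here is a minimal, untested sketch of what mining the hard positives as well could look like, under the assumption that the hardest positives are those with the lowest predicted probability (so the same top-k trick works with largest=False). The names hard_mining_both and BCE_OHEM_both are just mine for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

def hard_mining_both(output, labels, ratio, largest=True):
    ## hardest negatives have the HIGHEST predicted probability (largest=True);
    ## hardest positives have the LOWEST predicted probability (largest=False)
    num_inst = output.size(0)
    num_hard = max(int(ratio * num_inst), 1)
    _, idcs = torch.topk(output, min(num_hard, num_inst), largest=largest)
    return torch.index_select(output, 0, idcs), torch.index_select(labels, 0, idcs)

class BCE_OHEM_both(nn.Module):
    def __init__(self, ratio):
        super(BCE_OHEM_both, self).__init__()
        self.ratio = ratio

    def forward(self, inputs, targets):
        pos_idcs = targets[:, 0] >= 0.5
        pos_output = inputs[:, 0][pos_idcs]
        pos_labels = targets[:, 0][pos_idcs]
        neg_output = inputs[:, 0][~pos_idcs]
        neg_labels = targets[:, 0][~pos_idcs]
        ## mine both classes with the appropriate direction of "hard"
        neg_output, neg_labels = hard_mining_both(neg_output, neg_labels, self.ratio, largest=True)
        pos_output, pos_labels = hard_mining_both(pos_output, pos_labels, self.ratio, largest=False)
        neg_BCE_loss = F.binary_cross_entropy(neg_output, neg_labels, reduction='none')
        pos_BCE_loss = F.binary_cross_entropy(pos_output, pos_labels, reduction='none')
        return 0.5 * pos_BCE_loss.mean() + 0.5 * neg_BCE_loss.mean()

My hesitation is that positives are often the scarce class, so discarding the easy ones may throw away supervision the model still needs, which is (as far as I understand) why OHEM-style mining is usually applied to the abundant negatives only.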