Hello Pytorchers!! I have been trying to implement OHEM (Online Hard Example Mining) and have arrived at what I think is a reasonable version. Due to the lack of examples out there, I decided to share this implementation in the hope of getting an opinion and a fresh pair of eyes!
My question comes in two parts:

Is my solution correct? If not, please advise. It would also be great if anyone could point me towards a solution that is better optimized in terms of generalizability and computational cost.

Since my implementation only mines the hard negative examples, I was wondering whether the same can be done for the positive examples too. In other words, does taking the top-k hardest positive samples to compute the loss for the positive class make sense?
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def hard_mining(neg_output, neg_labels, ratio):
num_inst = neg_output.size(0)
num_hard = max(int(ratio * num_inst), 1)
_, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)))
neg_output = torch.index_select(neg_output, 0, idcs)
neg_labels = torch.index_select(neg_labels, 0, idcs)
return neg_output, neg_labels
class BCE_OHEM(nn.Module):
    """Binary cross-entropy with Online Hard Example Mining on negatives.

    The loss averages the positive-class BCE and the BCE over only the
    hardest ``ratio`` fraction of negatives, weighting the two class
    means equally (0.5 each) regardless of class counts.

    NOTE(review): assumes ``inputs`` are post-sigmoid probabilities and
    both ``inputs`` and ``targets`` are shaped (N, 1) — confirm against
    callers. Also assumes at least one positive and one negative sample
    per batch; an empty class would make its mean NaN (as it did in the
    original implementation).
    """

    def __init__(self, ratio):
        """ratio: fraction (0, 1] of negative samples kept as hard examples."""
        super(BCE_OHEM, self).__init__()
        self.ratio = ratio

    def forward(self, inputs, targets):
        """Compute the OHEM-weighted BCE loss.

        Args:
            inputs: (N, 1) predicted probabilities in [0, 1].
            targets: (N, 1) binary labels.

        Returns:
            Scalar loss tensor: 0.5 * mean(pos BCE) + 0.5 * mean(hard-neg BCE).
        """
        # Split the batch by label.
        pos_idcs = targets[:, 0] >= 0.5
        pos_output = inputs[pos_idcs]
        pos_labels = targets[pos_idcs]
        neg_idcs = targets[:, 0] < 0.5
        neg_output = inputs[:, 0][neg_idcs]
        neg_labels = targets[:, 0][neg_idcs]

        # Keep only the top-k hardest negatives (highest predicted prob).
        neg_output, neg_labels = hard_mining(neg_output, neg_labels, self.ratio)

        # Reduce each class to its mean BEFORE combining. The original
        # summed the two reduction='none' loss tensors, broadcasting the
        # (P, 1) positive losses against the (N,) negative losses into an
        # accidental (P, N) matrix; by linearity of the mean the result
        # was numerically identical, but it cost O(P*N) memory. This form
        # gives the same value without the quadratic intermediate.
        pos_BCE_loss = F.binary_cross_entropy(pos_output, pos_labels)
        neg_BCE_loss = F.binary_cross_entropy(neg_output, neg_labels)
        return 0.5 * pos_BCE_loss + 0.5 * neg_BCE_loss
if __name__ == '__main__':
    # Toy sanity check: four samples, two negatives then two positives.
    raw_scores = torch.tensor(np.array([[0.3], [0.8], [0.4], [0.99]]), requires_grad=True)
    probs = torch.sigmoid(raw_scores)  # [[0.4256], [0.6900], [0.4013], [0.7291]]
    labels = torch.tensor(np.array([[0.], [0.], [1.], [1.]]), requires_grad=False)

    # Plain BCE baseline over all four samples.
    print(F.binary_cross_entropy(probs, labels))  # tensor(0.7386, dtype=torch.float64, grad_fn=<BinaryCrossEntropyBackward>)

    # ratio=1.0 keeps every negative; with 2 pos / 2 neg this matches the baseline.
    print(BCE_OHEM(ratio=1.0)(probs, labels))  # tensor(0.7386, dtype=torch.float64, grad_fn=<MeanBackward0>)

    # ratio=0.5 keeps only the single hardest negative, so the loss rises.
    criterion = BCE_OHEM(ratio=0.5)
    print(criterion(probs, labels))  # tensor(0.8928, dtype=torch.float64, grad_fn=<MeanBackward0>)

    # Verify gradients flow back through the mined selection.
    criterion(probs, labels).backward()