Implementing online hard example mining (OHEM)

Hello PyTorchers! I have been trying to implement OHEM and have arrived at what I think is a reasonable version. Because there are few examples of this out there, I decided to share my implementation in the hope of getting feedback and a fresh pair of eyes.

My question comes in two parts:

  1. Is my solution correct? If not, please advise. It would also be great if anyone could point me towards a solution that is better optimized in terms of generalizability and computational cost.

  2. Since my implementation only mines hard negative examples, I was wondering whether the same can be done for the positive examples. In other words, does it make sense to compute the loss for the positive class over only the top-k hardest positive samples?

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def hard_mining(neg_output, neg_labels, ratio):
    """Keep only the hardest negatives, ranked by predicted score.

    Args:
        neg_output: scores for the negative samples; expected to be 1-D,
            since the selected indices are applied along dim 0. Higher
            scores rank as harder (for predicted probabilities of the
            positive class, these are the most confidently wrong negatives).
        neg_labels: labels aligned with ``neg_output`` along dim 0.
        ratio: fraction of the negatives to keep. At least one sample is
            requested, and never more than are available.

    Returns:
        ``(neg_output, neg_labels)`` restricted to the selected samples,
        ordered from highest to lowest score.
    """
    total = neg_output.size(0)
    # Clamp the requested count to [1, available]; with no samples this
    # yields k == 0.
    k = min(max(int(ratio * total), 1), len(neg_output))
    hardest = neg_output.topk(k).indices
    return (
        neg_output.index_select(0, hardest),
        neg_labels.index_select(0, hardest),
    )

class BCE_OHEM(nn.Module):
    """Binary cross-entropy with online hard negative mining.

    Every positive sample contributes to the loss. Of the negatives, only
    the fraction ``ratio`` with the highest predicted probability (selected
    by ``hard_mining``) contributes. The positive and mined-negative losses
    are each averaged within their own group, and the two group means are
    then averaged with equal weight.

    Args:
        ratio: fraction of the negatives in each batch to keep.
    """

    def __init__(self, ratio):
        super().__init__()
        self.ratio = ratio

    def forward(self, inputs, targets):
        """Compute the OHEM binary cross-entropy.

        Args:
            inputs: predicted probabilities, i.e. values already in [0, 1]
                as ``F.binary_cross_entropy`` requires. Indexed as
                ``inputs[:, 0]``; only column 0 is used.
            targets: binary targets indexed the same way; a sample counts as
                positive when ``targets[:, 0] >= 0.5`` and as negative when
                ``targets[:, 0] < 0.5``.

        Returns:
            A scalar tensor. When both classes are present it equals
            ``0.5 * mean(positive BCE) + 0.5 * mean(hard-negative BCE)``.
            When only one class is present it is that class's mean BCE.
            For an empty batch it is zero.
        """
        # Score both classes on the same output column that the masks are
        # built from, so positives and negatives are directly comparable.
        scores = inputs[:, 0]
        labels = targets[:, 0]
        pos_mask = labels >= 0.5
        neg_mask = labels < 0.5

        pos_output = scores[pos_mask]
        pos_labels = labels[pos_mask]
        neg_output = scores[neg_mask]
        neg_labels = labels[neg_mask]

        # Reduce each group to its own mean before combining. This stays
        # finite when one class is absent from the batch, and it avoids
        # building a pairwise (num_pos x num_neg) tensor.
        group_losses = []
        if pos_output.numel() > 0:
            group_losses.append(F.binary_cross_entropy(pos_output, pos_labels))
        if neg_output.numel() > 0:
            # Keep only the top-k hardest negatives before computing their loss.
            neg_output, neg_labels = hard_mining(neg_output, neg_labels, self.ratio)
            group_losses.append(F.binary_cross_entropy(neg_output, neg_labels))

        if not group_losses:
            # Empty batch: return a zero that is still connected to the graph.
            return inputs.sum() * 0.0
        # Equal weight per class present; this is 0.5/0.5 when both exist.
        return torch.stack(group_losses).mean()

if __name__ == '__main__':
    # Smoke test on four samples: two negatives followed by two positives.
    raw_scores = torch.tensor([[-0.3], [0.8], [-0.4], [0.99]],
                              dtype=torch.float64, requires_grad=True)
    probs = raw_scores.sigmoid()  # ~[[0.4256], [0.6900], [0.4013], [0.7291]]
    labels = torch.tensor([[0.], [0.], [1.], [1.]], dtype=torch.float64)

    # Reference: plain BCE over every sample (~0.7386).
    print(F.binary_cross_entropy(probs, labels))

    # Keeping every negative matches plain BCE here, because the two classes
    # have the same number of samples (~0.7386)...
    print(BCE_OHEM(ratio=1.0)(probs, labels))

    # ...while keeping only the hardest half of the negatives raises it (~0.8928).
    half = BCE_OHEM(ratio=0.5)
    print(half(probs, labels))

    # Check that gradients flow back through the mined loss to the raw scores.
    half(probs, labels).backward()