Top K gradient for Cross Entropy (OHEM)

Having seen a paper that mines the top 70% of gradients for backpropagation, I am wondering whether this strategy can really help improve performance. Some people call this Online Hard Example Mining (OHEM).

Attached below is my custom cross entropy implementation, which keeps only the top k percent of per-sample losses (and therefore gradients) for binary classification. I have tested it with top_k = 100%, and the result is exactly the same as the original nn.CrossEntropyLoss().

May I ask, is there a better way to achieve this goal? Do you think this is good practice?


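import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

class topk_crossEntrophy(nn.Module):

    def __init__(self, top_k=0.7):
        super(topk_crossEntrophy, self).__init__()
        self.loss = nn.NLLLoss()
        self.top_k = top_k
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, target):
        softmax_result = self.softmax(input)
        # Dummy first element so that torch.cat has something to append to.
        loss = Variable(torch.Tensor(1).zero_())
        for idx, row in enumerate(softmax_result):
            gt = target[idx:idx + 1]          # keep a batch dimension of 1
            pred = torch.unsqueeze(row, 0)
            cost = self.loss(pred, gt)
            loss = torch.cat((loss, cost.view(1)), 0)
        loss = loss[1:]                       # drop the dummy element
        if self.top_k == 1:
            valid_loss = loss
        else:
            # torch.topk returns (values, indices); keep the largest losses.
            index = torch.topk(loss, int(self.top_k * loss.size()[0]))
            valid_loss = loss[index[1]]
        return torch.mean(valid_loss)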


a = torch.randn((10,2))
b = np.random.randint(2, size=10)
b = torch.from_numpy(b.astype(np.float32)).type(torch.LongTensor)

topk_loss = topk_crossEntrophy()
loss = topk_loss(Variable(a, requires_grad=True), Variable(b))
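
For reference, the top_k = 100% check described above can also be written without the per-row loop. This is only a minimal sketch, assuming a recent PyTorch in which `F.cross_entropy` accepts `reduction="none"`; `topk_cross_entropy` is a hypothetical helper name used for illustration, not a PyTorch API.

```python
import torch
import torch.nn.functional as F


def topk_cross_entropy(logits, target, top_k=0.7):
    """Mean cross entropy over the top_k fraction of hardest samples."""
    # Per-sample losses, shape (N,); no Python loop needed.
    per_sample = F.cross_entropy(logits, target, reduction="none")
    if top_k >= 1:
        return per_sample.mean()
    # Keep at least one sample so the mean is never taken over an empty tensor.
    k = max(1, int(top_k * per_sample.numel()))
    hardest, _ = torch.topk(per_sample, k)
    return hardest.mean()


torch.manual_seed(0)
logits = torch.randn(10, 2, requires_grad=True)
target = torch.randint(0, 2, (10,))

# With top_k = 1.0 the result should match the built-in mean cross entropy.
full = topk_cross_entropy(logits, target, top_k=1.0)
reference = F.cross_entropy(logits, target)
same_as_reference = torch.allclose(full, reference)

# With top_k = 0.7 only the 7 hardest of the 10 samples contribute.
partial = topk_cross_entropy(logits, target, top_k=0.7)
partial.backward()
# Rows outside the selected top-k receive zero gradient.
rows_with_grad = int((logits.grad.abs().sum(dim=1) > 0).sum())
```

Because the mean is taken over the selected samples only, the top-k loss is never smaller than the plain mean, and the discarded samples do not influence the update at all.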


I think this is exactly what the OHEM method does. I have watched their videos: it achieves better performance in several respects and is easy to implement.

Will the for loop in this method add a lot of overhead, especially in segmentation tasks?

Hi, I implemented a version of OHEM for torch versions >= 0.4:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss

class topk_crossEntrophy(_WeightedLoss):

    def __init__(self, top_k=0.7, weight=None, size_average=None,
                 ignore_index=-100, reduce=None, reduction='none'):
        super(topk_crossEntrophy, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.top_k = top_k
        # Per-sample losses are needed for the top-k selection, hence reduction='none'.
        self.loss = nn.NLLLoss(weight=self.weight,
                               ignore_index=self.ignore_index, reduction='none')

    def forward(self, input, target):
        loss = self.loss(F.log_softmax(input, dim=1), target)
        if self.top_k == 1:
            return torch.mean(loss)
        # Keep only the top_k fraction of samples with the largest loss.
        valid_loss, idxs = torch.topk(loss, int(self.top_k * loss.size()[0]))
        return torch.mean(valid_loss)
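
On the segmentation question raised earlier: with `reduction='none'`, an (N, C, H, W) input gives a loss of shape (N, H, W), and `torch.topk` ranks along the last dimension only, so the per-pixel losses would need to be flattened first. Below is a rough sketch of one way to do that, not a definitive implementation. It assumes a recent PyTorch, and it assumes you want the hardest fraction over all non-ignored pixels in the batch; `topk_pixel_cross_entropy` is a hypothetical name of my own.

```python
import torch
import torch.nn.functional as F


def topk_pixel_cross_entropy(logits, target, top_k=0.7, ignore_index=-100):
    """Mean cross entropy over the hardest top_k fraction of valid pixels."""
    # logits: (N, C, H, W), target: (N, H, W) -> per-pixel loss of shape (N, H, W)
    per_pixel = F.cross_entropy(
        logits, target, ignore_index=ignore_index, reduction="none"
    )
    # Boolean indexing flattens the result, so top-k ranks all pixels in the
    # batch together. Ignored pixels (loss 0) are dropped so they cannot
    # dilute the mean.
    valid = per_pixel[target != ignore_index]
    if valid.numel() == 0:
        return per_pixel.sum()  # zero, but still connected to the graph
    k = max(1, int(top_k * valid.numel()))
    hardest, _ = torch.topk(valid, k)
    return hardest.mean()


torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4, requires_grad=True)
target = torch.randint(0, 3, (2, 4, 4))
target[0, 0, :] = -100  # mark 4 of the 32 pixels as ignored

loss = topk_pixel_cross_entropy(logits, target, top_k=0.5)
loss.backward()
# 28 valid pixels remain, so 14 of them should receive a gradient.
pixels_with_grad = int((logits.grad.abs().sum(dim=1) > 0).sum())
```

Everything here is a batched tensor operation, so there is no per-pixel Python loop; the main extra cost over plain cross entropy is the `torch.topk` call on the flattened losses.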