i am trying to use OSUMLoss like this :
def get_loss(name):
loss_dict = {
'CrossEntropy': F.cross_entropy,
'SoftCrossEntropy': soft_cross_entropy_loss,
'LabelSmoothingCrossEntropy':
LabelSmoothingCrossEntropy(epsilon=0.1),
'SCE': SymmetricCrossEntropy(),
'BCEWithLogitsLoss': F.binary_cross_entropy_with_logits,
'Coral': coral_loss,
'MSELoss': F.mse_loss,
'MAELoss': F.l1_loss,
'huber': F.smooth_l1_loss,
}
return loss_dict[name]
class OUSMLoss(nn.Module):
'''
Implementation of
Loss with Online Uncertainty Sample Mining:
https://arxiv.org/pdf/1901.07759.pdf
# Params
k: num of samples to drop in a mini batch
loss: loss function name (see get_loss function above)
trigger: the epoch it starts to train on OUSM (please call `.update(epoch)` each epoch)
'''
def __init__(self, k=1, trigger=5, ousm=True):
super(OUSMLoss, self).__init__()
self.k = k
self.loss_name = 'CrossEntropy'
self.loss = get_loss(self.loss_name)
self.trigger = trigger
self.ousm = ousm
def forward(self, logits, targets, indices=None):
logits, targets = _check_input_type(logits, targets, self.loss_name)
bs = logits.shape[0]
if self.ousm and bs - self.k > 0:
losses = self.loss(logits, targets, reduction='none')
if len(losses.shape) == 2:
losses = losses.mean(1)
_, idxs = losses.topk(bs-self.k, largest=False)
losses = losses.index_select(0, idxs)
return losses.mean()
else:
return self.loss(logits, targets)
def update(self, current_epoch):
self.current_epoch = current_epoch
if current_epoch == self.trigger:
self.ousm = True
print('criterion: ousm is True.')
def __repr__(self):
return f'OUSM(loss={self.loss_name}, k={self.k}, trigger={self.trigger}, ousm={self.ousm})'
when i do :
loss = OUSMLoss(logits, target)
loss.backward()
i get,
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<timed exec> in <module>
<ipython-input-105-6eecc42fef33> in train_epoch(loader, optimizer)
16 loss = OUSMLoss(logits, target)
17
---> 18 loss.backward()
19 optimizer.step()
20
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
592 return modules[name]
593 raise AttributeError("'{}' object has no attribute '{}'".format(
--> 594 type(self).__name__, name))
595
596 def __setattr__(self, name, value):
AttributeError: 'OUSMLoss' object has no attribute 'backward'
how can i solve this issue? @ptrblck