I'm coding a multi-task classification model with my own customized loss function.
The tasks all have the same number of classes, but each label variable has its own set of class weights. So I instantiated a separate loss-function object for every label variable and let each one hold its own weights buffer via register_buffer().
Now I need to confirm whether register_buffer() really creates an independent buffer per instance, even when the same name is used, so that I can simply access the correct one like a normal member variable of each loss object.
In my example shown below, although I call register_buffer with the same name ('cls_wts') for every instance of SoftLabelLoss, does each object in fact get its own member named cls_wts with its own weight values (no global values, no overwrites)?
If so, I do NOT need to look up the correct buffer by iterating model.buffers() when computing the loss?
Thanks!!!
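To illustrate what I'm trying to confirm, here is a minimal sketch (the Dummy module and its values are made up just for this check) that I would expect to print two different tensors if buffers really are per-instance:

import torch
import torch.nn as nn

class Dummy(nn.Module):
    def __init__(self, value: float):
        super().__init__()
        # register a buffer under the same name in every instance
        self.register_buffer('cls_wts', torch.full((3,), value))

a = Dummy(1.0)
b = Dummy(2.0)
print(a.cls_wts)                       # expecting tensor([1., 1., 1.])
print(b.cls_wts)                       # expecting tensor([2., 2., 2.])
print(dict(a.named_buffers()).keys())  # expecting dict_keys(['cls_wts'])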
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftLabelLoss(nn.Module):
    def __init__(self, class_cnt: int, weights: np.ndarray = None):
        super().__init__()
        cls_idx = torch.arange(class_cnt, dtype=torch.float32, requires_grad=False)
        cls_idx = cls_idx.unsqueeze(0)                        # (1, class_cnt)
        self.register_buffer(name='cls_idx', tensor=cls_idx)
        if weights is not None:                               # truthiness of an ndarray is ambiguous
            cls_wts = torch.tensor(weights, dtype=torch.float32, requires_grad=False)
            self.register_buffer(name='cls_wts', tensor=cls_wts)
        else:
            self.cls_wts = None

    def forward(self, logits, target):
        with torch.no_grad():
            # build soft labels from the distance of each class index to the target
            ftgts = target.unsqueeze(1).float()               # (batch_size, 1)
            dists = torch.abs(ftgts - self.cls_idx)           # (batch_size, class_cnt)
            unnorm = torch.exp(-dists)
            soft_labs = unnorm / unnorm.sum(dim=1, keepdim=True)
        log_probs = F.log_softmax(logits, dim=1)              # (batch_size, class_cnt)
        loss = -(soft_labs * log_probs).sum(dim=1)            # (batch_size,)
        if self.cls_wts is not None:                          # avoid tensor truthiness error
            loss = loss * self.cls_wts[target]                # per-sample class weighting
        return loss.mean()
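For context, this is roughly how I intend to use it (class count and weight arrays below are placeholder values, not my real data):

import numpy as np
import torch

# two label variables, same class count, different class weights
loss_a = SoftLabelLoss(class_cnt=5, weights=np.array([1.0, 2.0, 1.0, 1.0, 0.5]))
loss_b = SoftLabelLoss(class_cnt=5, weights=np.array([0.5, 1.0, 1.0, 2.0, 1.0]))

print(loss_a.cls_wts)  # I expect each instance to keep its own values here
print(loss_b.cls_wts)

logits = torch.randn(4, 5)
target = torch.tensor([0, 2, 4, 1])
print(loss_a(logits, target), loss_b(logits, target))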