Hi, I have a simple snippet of code, and I'm wondering whether there is a way to make it faster:
```python
# Slow version: visits every (row, class) element in a Python double loop
cls_score_shifted = torch.zeros((cls_score.shape[0], cls_score.shape[1]), device='cuda')
for i in range(cls_score.shape[0]):
    for j in range(cls_score.shape[1]):
        if targets[i] == j:
            cls_score_shifted[i, j] = cls_score[i, j] + 1
        else:
            cls_score_shifted[i, j] = cls_score[i, j]
```
I want to shift a class score by 1 if it is the correct (target) logit, and leave all the other elements unchanged. Is there a quicker way to do this?
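For comparison, here is a different technique: advanced indexing can add the shift on a copy while touching only one element per row. This is a minimal sketch that assumes `targets` is a 1-D int64 tensor of class indices with one entry per row of `cls_score`. The function name `shift_target_logits` and the toy tensors are hypothetical.

```python
import torch

def shift_target_logits(cls_score: torch.Tensor, targets: torch.Tensor,
                        shift: float = 1.0) -> torch.Tensor:
    # Hypothetical helper: copy first so the caller's scores are not modified.
    shifted = cls_score.clone()
    # Pair row i with column targets[i]: exactly one element selected per row.
    rows = torch.arange(cls_score.shape[0], device=cls_score.device)
    shifted[rows, targets] += shift
    return shifted

# Toy example data (assumed): 2 samples, 3 classes.
cls_score = torch.tensor([[0.1, 0.2, 0.3],
                          [1.0, 2.0, 3.0]])
targets = torch.tensor([2, 0])
shifted = shift_target_logits(cls_score, targets)
print(shifted)
```

Because each (row, target) pair is unique, the in-place `+=` on the indexed copy is well defined. It also stays differentiable, since `clone()` returns a non-leaf tensor.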
Thanks!
Edit:
This seems a lot faster, since the Python loop now runs once per row instead of once per element:
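```python
# Build a one-hot mask row by row, then do the addition as a single tensor op
cls_score_mask = torch.zeros((cls_score.shape[0], cls_score.shape[1]), device='cuda')
for i, mask_row in enumerate(cls_score_mask):
    mask_row[targets[i]] = 1  # mask_row is a view, so this writes into cls_score_mask
cls_score_shifted = cls_score + cls_score_mask
```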
There is probably still a better way to do this.
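One loop-free way to build the same one-hot mask is `torch.nn.functional.one_hot`. This is a minimal sketch with made-up random scores. It assumes `targets` holds int64 class indices in `[0, num_classes)`, because `one_hot` rejects other dtypes and out-of-range values.

```python
import torch
import torch.nn.functional as F

# Example data (assumed): 4 samples, 5 classes.
cls_score = torch.randn(4, 5)
targets = torch.tensor([0, 3, 3, 1])  # int64 class index per row

# one_hot returns an int64 tensor of shape (N, num_classes) with a 1 at each
# target column; cast it to the score dtype before adding.
cls_score_mask = F.one_hot(targets, num_classes=cls_score.shape[1]).to(cls_score.dtype)
cls_score_shifted = cls_score + cls_score_mask
```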
Edit2:
Okay, got it. `scatter_` builds the same mask without any Python loop:
```python
# scatter_(dim=1, index, value) writes 1 at column targets[i] of each row i
cls_score_mask = torch.zeros_like(cls_score).scatter_(1, targets.unsqueeze(1), 1)
cls_score_shifted = cls_score + cls_score_mask
```
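To confirm that the `scatter_` version computes the same thing as the original double loop, you can compare the two on small random tensors. This sketch runs on the CPU so that it works without a GPU. The function names `shift_loop` and `shift_scatter` are hypothetical wrappers around the code in this thread.

```python
import torch

def shift_loop(cls_score, targets):
    # Port of the original double loop (slow; used only as a reference).
    out = torch.zeros_like(cls_score)
    for i in range(cls_score.shape[0]):
        for j in range(cls_score.shape[1]):
            if targets[i] == j:
                out[i, j] = cls_score[i, j] + 1
            else:
                out[i, j] = cls_score[i, j]
    return out

def shift_scatter(cls_score, targets):
    # Loop-free version: one-hot mask via scatter_, then a single addition.
    mask = torch.zeros_like(cls_score).scatter_(1, targets.unsqueeze(1), 1)
    return cls_score + mask

torch.manual_seed(0)
cls_score = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))  # int64 indices, as scatter_ requires
assert torch.equal(shift_loop(cls_score, targets), shift_scatter(cls_score, targets))
```

The loop is slow because it launches a separate tiny operation for every element. On CUDA, it is slower still: evaluating `if targets[i] == j` converts a GPU tensor to a Python bool, which forces a host-device synchronization on every iteration.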