[SOLVED] Index-based element update for PyTorch tensor

Hi, I have a simple snippet of code, and I'm wondering if there is a way to make it faster:

        cls_score_shifted = torch.zeros((cls_score.shape[0],cls_score.shape[1]),device='cuda')
        for i in range(cls_score.shape[0]):
            for j in range(cls_score.shape[1]):
                if(targets[i] == j):
                    cls_score_shifted[i,j] = cls_score[i,j] + 1
                else:
                    cls_score_shifted[i,j] = cls_score[i,j]

I want to conditionally update (shift) my class score if it is the correct (target) logit, and leave the other elements alone. Is there a quicker way to do this?



This seems a lot faster:

        cls_score_mask = torch.zeros((cls_score.shape[0],cls_score.shape[1]),device='cuda')
        for i,mask_row in enumerate(cls_score_mask):
            mask_row[targets[i]] = 1
        cls_score_shifted = cls_score + cls_score_mask

There is probably still a better way to do this.

Okay, got it.

        cls_score_mask = torch.zeros_like(cls_score).scatter_(1, targets.unsqueeze(1), 1)
        cls_score_shifted = cls_score + cls_score_mask
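
For anyone landing here later, here is a small sketch of another way that skips the mask and indexes the target positions directly, with a check against the scatter_ version above. The helper name shift_target_logits, the shift argument, and the toy tensors are made up for illustration. It assumes cls_score has shape (N, C) and targets is an int64 tensor of shape (N,), and it runs on CPU so it can be tried without a GPU.

```python
import torch

def shift_target_logits(cls_score, targets, shift=1.0):
    # Hypothetical helper: add `shift` to the target-class logit of each row.
    # One row index per sample, paired with that sample's target column.
    rows = torch.arange(cls_score.shape[0], device=cls_score.device)
    shifted = cls_score.clone()        # leave the input tensor untouched
    shifted[rows, targets] += shift    # only positions (i, targets[i]) change
    return shifted

# Toy data (made up): 2 samples, 3 classes.
cls_score = torch.tensor([[0.1, 0.2, 0.3],
                          [1.0, 2.0, 3.0]])
targets = torch.tensor([2, 0])  # int64 class indices

by_indexing = shift_target_logits(cls_score, targets)

# The scatter_-based version from the post, for comparison.
mask = torch.zeros_like(cls_score).scatter_(1, targets.unsqueeze(1), 1)
by_scatter = cls_score + mask

print(torch.allclose(by_indexing, by_scatter))
```

Both versions avoid the Python loop. I have not timed them against each other, so which one is faster for a given shape is something to measure rather than assume.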