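Is there a faster way to do this using advanced indexing? The loop below builds a one-hot tensor gt that marks, for each position, which slice along dim 0 of sp holds the maximum.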
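import torch

sp = torch.rand(3, 4, 4)
indices = sp.max(dim=0)[1]  # argmax along dim 0, shape (4, 4)
gt = torch.zeros_like(sp)
for idx, sp_i in enumerate(gt):
    # sp_i is a view of gt[idx]; set it to 1 wherever slice idx holds the max
    sp_i.masked_fill_(indices == idx, 1.)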
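Here is a minimal sketch of loop-free alternatives. It is not benchmarked, but each one should replace the Python-level loop with a single vectorized op. The variable names (gt_loop, gt_adv, gt_scatter, gt_cmp) are only for illustration. Option 1 uses advanced indexing directly: the index tensors broadcast to shape (4, 4), so the assignment sets gt[indices[r, c], r, c] = 1 for every (r, c). Option 2 uses Tensor.scatter_ along dim 0. Option 3 compares indices against arange(C) via broadcasting. The snippet checks that all three match the original loop.

```python
import torch

# Same setup as in the question: 3 slices along dim 0 over a 4x4 grid
sp = torch.rand(3, 4, 4)
indices = sp.max(dim=0)[1]  # argmax over dim 0, shape (4, 4), dtype int64

# Reference: the original loop
gt_loop = torch.zeros_like(sp)
for idx, sp_i in enumerate(gt_loop):
    sp_i.masked_fill_(indices == idx, 1.)

C, H, W = sp.shape

# Option 1: advanced indexing; rows (H, 1) and cols (1, W) broadcast with indices (H, W)
rows = torch.arange(H, device=sp.device).unsqueeze(1)
cols = torch.arange(W, device=sp.device).unsqueeze(0)
gt_adv = torch.zeros_like(sp)
gt_adv[indices, rows, cols] = 1.

# Option 2: scatter a 1 along dim 0 at the argmax position
gt_scatter = torch.zeros_like(sp).scatter_(0, indices.unsqueeze(0), 1.)

# Option 3: broadcast-compare indices (1, H, W) against class ids (C, 1, 1)
class_ids = torch.arange(C, device=sp.device).view(-1, 1, 1)
gt_cmp = (indices.unsqueeze(0) == class_ids).to(sp.dtype)

print(torch.equal(gt_loop, gt_adv),
      torch.equal(gt_loop, gt_scatter),
      torch.equal(gt_loop, gt_cmp))
```

Each of the three produces the same one-hot tensor as the loop. Option 3 is arguably the most readable. The advanced-indexing version answers the question most literally.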