Fill a 3D tensor along an arbitrary dimension given 2D indices

Is there a faster way to do this than the Python loop below, e.g. by using advanced indexing?

import torch

# One-hot encode the position of the maximum of `sp` along `dim`.
# Afterwards (for dim=0): gt[k, i, j] == 1. exactly where k == indices[i, j],
# and 0. everywhere else.
dim = 0

sp = torch.rand(3, 4, 4)
# `.indices` is the same tensor the positional `[1]` selects from the
# (values, indices) result of max: the argmax along `dim`, with that dimension
# removed from the shape.
indices = sp.max(dim=dim).indices
gt = torch.zeros_like(sp)

# Vectorized replacement for the per-slice `masked_fill_` loop: one scatter
# writes 1. at gt[indices[i, j], i, j] (for dim=0) instead of one comparison
# and one fill per slice. `unsqueeze` re-inserts the reduced dimension so the
# index tensor has the same rank as `gt`, as scatter_ requires.
gt.scatter_(dim, indices.unsqueeze(dim), 1.0)