Masking a Chamfer distance loss per sample


I compute a Chamfer loss for each part of the object and would like to mask out some of these losses per sample in the batch. I am a bit confused about how to apply the mask correctly. Could you help me?

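# assuming pytorch3d.loss.chamfer_distance: it returns a tuple (loss, loss_normals), so unpack before masking;
# batch_reduction=None keeps one loss value per sample, shape (batch_size,)
loss_ch, _ = chamfer_distance(src_0, trg_0, batch_reduction=None)
loss_ch = (loss_ch * mask[0]).sum() / mask[0].sum().clamp(min=1)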

src_0:   torch.Size([2, 1538, 3])
trg_0:   torch.Size([2, 1538, 3])
mask[0]: torch.Size([2])  # one binary value (0 or 1) per sample

batch_size = 2
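For reference, here is a minimal pure-PyTorch sketch of per-sample masking. It assumes a symmetric Chamfer distance built from squared Euclidean distances averaged over points; the helper names chamfer_per_sample and masked_mean are hypothetical and not part of any library. The idea is to keep one loss value per sample, multiply it by the mask, and normalize by the number of kept samples, so masked samples contribute neither to the loss value nor to the gradients. If the chamfer_distance in use is pytorch3d's, its documented weights argument (shape (N,), one weight per batch element) may do the same job; check the documentation of the installed version.

```python
import torch


def chamfer_per_sample(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Symmetric Chamfer distance for each batch element (hypothetical helper).

    x: (N, P1, 3) and y: (N, P2, 3). Returns a tensor of shape (N,).
    Assumption: squared Euclidean distances, averaged over points.
    """
    # Pairwise squared distances between every point of x and of y: (N, P1, P2).
    # This needs O(P1 * P2) memory; a KNN-based implementation scales better.
    d = (x.unsqueeze(2) - y.unsqueeze(1)).pow(2).sum(dim=-1)
    # Nearest neighbour in y for each point of x, averaged over the points of x.
    x_to_y = d.min(dim=2).values.mean(dim=1)
    # Nearest neighbour in x for each point of y, averaged over the points of y.
    y_to_x = d.min(dim=1).values.mean(dim=1)
    return x_to_y + y_to_x


def masked_mean(per_sample: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average the per-sample losses over the samples where mask == 1."""
    mask = mask.to(per_sample.dtype)
    # Clamp the denominator so an all-zero mask gives 0 instead of NaN.
    return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)


if __name__ == "__main__":
    torch.manual_seed(0)
    src_0 = torch.randn(2, 1538, 3, requires_grad=True)
    trg_0 = torch.randn(2, 1538, 3)
    mask_0 = torch.tensor([1, 0])  # keep sample 0, mask out sample 1

    loss = masked_mean(chamfer_per_sample(src_0, trg_0), mask_0)
    loss.backward()
    # The masked sample (index 1) receives an all-zero gradient.
    print(loss.item(), src_0.grad[1].abs().max().item())
```

Dividing by the number of kept samples rather than by batch_size keeps the loss scale independent of how many samples are masked out; dividing by batch_size instead is also a valid choice if masked samples should simply shrink the total loss.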