I compute a chamfer loss for different parts of an object and would like to mask out some of those losses. I am a little confused about how to do this. Could you help me? This is what I tried, followed by the tensor shapes:
loss_ch, _ = chamfer_distance(src_0, trg_0) * mask
src_0: torch.Size([2, 1538, 3])
trg_0: torch.Size([2, 1538, 3])
mask: torch.Size([])  # a scalar, just a binary 0 or 1
batch_size = 2
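
One possible source of the confusion: if `chamfer_distance` is the PyTorch3D function, it returns a tuple `(loss, loss_normals)`, so `chamfer_distance(...) * mask` multiplies the tuple itself rather than the loss. Unpacking first and multiplying afterwards avoids that. Below is a minimal sketch of the masking step in plain PyTorch. The helper `chamfer_per_sample` is a hypothetical stand-in written for this example (it is not the PyTorch3D implementation), and `mask_b` is an assumed per-sample variant of the mask, in case individual batch elements should be masked rather than the whole loss.

```python
import torch

def chamfer_per_sample(src, trg):
    """Hypothetical stand-in: symmetric squared chamfer distance,
    one value per batch element (shape (B,))."""
    d = torch.cdist(src, trg) ** 2                 # (B, N, M) pairwise squared distances
    src_to_trg = d.min(dim=2).values.mean(dim=1)   # nearest target for each source point
    trg_to_src = d.min(dim=1).values.mean(dim=1)   # nearest source for each target point
    return src_to_trg + trg_to_src

torch.manual_seed(0)
src_0 = torch.rand(2, 1538, 3, requires_grad=True)
trg_0 = torch.rand(2, 1538, 3)

# Case 1: scalar mask (0 or 1) applied to the batch-averaged loss.
# Compute the loss first, then multiply the resulting tensor.
mask = torch.tensor(0.0)
loss_ch = chamfer_per_sample(src_0, trg_0).mean() * mask

# Case 2 (assumption): one mask entry per batch element, shape (B,).
mask_b = torch.tensor([1.0, 0.0])
per_sample = chamfer_per_sample(src_0, trg_0)      # shape (2,)
# Average over the unmasked samples only; clamp avoids dividing by zero
# when every sample is masked out.
loss_ch_b = (per_sample * mask_b).sum() / mask_b.sum().clamp(min=1.0)
```

With the PyTorch3D function the same pattern would read `loss_ch, _ = chamfer_distance(src_0, trg_0)` followed by `loss_ch = loss_ch * mask`. A loss multiplied by 0 still runs the forward pass and simply contributes zero gradient, so skipping the computation when the mask is 0 is an alternative if speed matters.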