I recently needed to perform masked operations on PyTorch tensors. Here’s how I implemented them (taken from my blog post):
def masked_mean(tensor, mask, dim):
    """Return the mean of ``tensor`` along ``dim``, counting only positions selected by ``mask``.

    Args:
        tensor: Values to average.
        mask: Tensor broadcastable with ``tensor``. A boolean mask selects the
            entries to include (``True`` = include). A numeric mask is treated as
            per-element weights, giving a weighted mean (the original behavior).
        dim: Dimension to reduce.

    Returns:
        Tensor with ``dim`` reduced away. A slice with no selected entries
        yields NaN (0 / 0), like ``torch.mean`` on an empty input.
    """
    # Expand both operands to a common shape so the denominator counts one
    # mask entry per tensor element, even when the mask was only broadcastable.
    tensor, mask = torch.broadcast_tensors(tensor, mask)
    if mask.dtype == torch.bool:
        # masked_fill replaces excluded entries outright; multiplying by 0 would
        # turn an inf there into NaN and poison the whole sum.
        masked = tensor.masked_fill(~mask, 0)
    else:
        masked = tensor * mask  # weighted mean for numeric masks
    return masked.sum(dim=dim) / mask.sum(dim=dim)
def masked_max(tensor, mask, dim):
    """Return the max of ``tensor`` along ``dim`` over positions selected by ``mask``.

    Args:
        tensor: Values to reduce (floating-point or integer dtype).
        mask: Tensor broadcastable with ``tensor``; nonzero / ``True`` entries are
            included, all others are ignored.
        dim: Dimension to reduce.

    Returns:
        The ``(values, indices)`` named tuple produced by ``Tensor.max(dim=...)``.
        A slice with nothing selected yields the fill value (``-inf`` for
        floating dtypes, the dtype's minimum for integer dtypes), and its index
        is not meaningful.
    """
    # Coerce to bool so ``~`` is a logical NOT rather than a bitwise NOT, and
    # broadcast so a smaller mask (e.g. per-column) lines up with the tensor.
    tensor, mask = torch.broadcast_tensors(tensor, mask.bool())
    if tensor.is_floating_point():
        fill = -math.inf
    else:
        # Integer tensors cannot hold -inf; their minimum plays the same role.
        fill = torch.iinfo(tensor.dtype).min
    # Overwrite excluded entries directly; multiplying by 0 and adding -inf
    # would turn an inf/NaN there into NaN and make max() return NaN.
    return tensor.masked_fill(~mask, fill).max(dim=dim)