Masked instance norm

Hello. I would like to use 1d instance normalization, but I cannot use nn.InstanceNorm1d because my objects are padded. My input has shape batch_size (N), num_objects (L), features (C), and each sample in the batch has a different number of objects, so the object dimension is padded and I keep a boolean mask of shape batch_size (N), num_objects (L) marking the valid positions. Using nn.InstanceNorm1d directly would include the padded objects in the mean and std computation, which I don't want, so I implemented it from scratch. However, it doesn't give me the accuracy I expect (I know the expected accuracy because I am re-implementing a paper). It would be helpful if someone could check my code.
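
For concreteness, the kind of mask I mean could be built from per-sample object counts like this (the sizes N, L, C and the lengths tensor below are just illustrative):

import torch

N, L, C = 3, 5, 4
lengths = torch.tensor([5, 3, 2])                           # real object count per sample
x = torch.randn(N, L, C)                                    # padded input
mask = torch.arange(L).unsqueeze(0) < lengths.unsqueeze(1)  # (N, L) boolean mask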

EDIT: I realized that the problem was that detaching the mean and variance computation is necessary. I have added the detach calls in the code below.

import torch

def masked_instance_norm(x, mask, eps=1e-5):
    """
    x of shape: [batch_size (N), num_objects (L), features (C)]
    mask of shape: [batch_size (N), num_objects (L)]
    """
    mask = mask.float().unsqueeze(-1)  # (N, L, 1)
    mean = torch.sum(x * mask, 1) / torch.sum(mask, 1)  # (N, C), mean over valid positions only
    mean = mean.detach()
    var_term = ((x - mean.unsqueeze(1).expand_as(x)) * mask) ** 2  # (N, L, C), padded slots contribute 0
    var = torch.sum(var_term, 1) / torch.sum(mask, 1)  # (N, C)
    var = var.detach()
    mean_reshaped = mean.unsqueeze(1).expand_as(x)  # (N, L, C)
    var_reshaped = var.unsqueeze(1).expand_as(x)    # (N, L, C)
    ins_norm = (x - mean_reshaped) / torch.sqrt(var_reshaped + eps)  # (N, L, C)
    return ins_norm
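
As a rough sanity check, the normalized values at the valid positions should not depend on whatever happens to sit in the padded slots. Continuing the illustrative x, mask, and lengths from the sketch above:

out1 = masked_instance_norm(x, mask)
x_junk = x.clone()
x_junk[~mask] = 100.0                          # overwrite padded slots with junk values
out2 = masked_instance_norm(x_junk, mask)
print(torch.allclose(out1[mask], out2[mask]))  # expected: True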

Verifying against PyTorch's nn.InstanceNorm1d (using an all-ones mask for my version):

import torch.nn as nn

# PyTorch nn.InstanceNorm1d (expects input of shape (N, C, L))
m = nn.InstanceNorm1d(100)
input = torch.randn(20, 100, 40)
pytorch_output = m(input)

# Mine
input = input.permute(0, 2, 1)  # (N, L, C), the shape expected by masked_instance_norm
mask = torch.ones(20, 40)       # mask is all ones for comparison purposes
my_output = masked_instance_norm(input, mask)
my_output = my_output.permute(0, 2, 1)  # back to (N, C, L)

They give the same result.
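
The equivalence can also be checked numerically; since both use eps=1e-5, the outputs should agree up to floating-point tolerance:

print(torch.allclose(pytorch_output, my_output, atol=1e-5))  # expected: True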
