Hello. I would like to use instance normalization (1d), but I cannot use nn.InstanceNorm1d because my objects are masked. For example, I have an input of shape
batch_size (N), num_objects (L), features (C), and each batch element has a different number of objects, so the number of objects is not fixed. Therefore, I have a boolean mask of shape
batch_size (N), num_objects (L) for that. Using nn.InstanceNorm1d would compute the mean and std across the padded objects as well, which I don't want, so I've implemented it from scratch. But it doesn't seem to give me the accuracy I am expecting (I know the expected accuracy because I am re-implementing a paper). It would be helpful if someone could check my code.
EDIT: I realized that the problem was that detaching the mean and var calculations is necessary. I've added the detach calls in the code below.
import torch

def masked_instance_norm(x, mask, eps=1e-5):
    """
    x    of shape: [batch_size (N), num_objects (L), features (C)]
    mask of shape: [batch_size (N), num_objects (L)]
    """
    mask = mask.float().unsqueeze(-1)                                 # (N, L, 1)
    mean = torch.sum(x * mask, 1) / torch.sum(mask, 1)                # (N, C), stats over valid objects only
    mean = mean.detach()
    var_term = ((x - mean.unsqueeze(1).expand_as(x)) * mask) ** 2     # (N, L, C)
    var = torch.sum(var_term, 1) / torch.sum(mask, 1)                 # (N, C)
    var = var.detach()
    mean_reshaped = mean.unsqueeze(1).expand_as(x)                    # (N, L, C)
    var_reshaped = var.unsqueeze(1).expand_as(x)                      # (N, L, C)
    ins_norm = (x - mean_reshaped) / torch.sqrt(var_reshaped + eps)   # (N, L, C)
    return ins_norm
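One thing worth noting: nn.InstanceNorm1d also supports learnable affine parameters when constructed with affine=True. If you need those as well, here is a minimal sketch of my own (not part of the original code) that wraps the function in a small module with a per-channel weight and bias:

import torch
import torch.nn as nn

class MaskedInstanceNorm1d(nn.Module):
    """Masked instance norm with optional per-channel affine parameters (sketch)."""
    def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(num_features))   # gamma, shape (C,)
            self.bias = nn.Parameter(torch.zeros(num_features))    # beta, shape (C,)

    def forward(self, x, mask):
        # x: (N, L, C), mask: (N, L)
        out = masked_instance_norm(x, mask, self.eps)
        if self.affine:
            out = out * self.weight + self.bias                    # broadcasts over (N, L, C)
        return out

With affine=False (the nn.InstanceNorm1d default, which is what the comparison below uses) this reduces to the plain function.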
Verifying against PyTorch's nn.InstanceNorm1d (without a mask):
import torch
import torch.nn as nn

# PyTorch nn.InstanceNorm1d
m = nn.InstanceNorm1d(100)
input = torch.randn(20, 100, 40)          # (N, C, L)
pytorch_output = m(input)

# Mine
input = input.permute(0, 2, 1)            # (N, L, C), the shape my function expects
mask = torch.ones(20, 40)                 # (N, L); all ones for comparison purposes
my_output = masked_instance_norm(input, mask)
my_output = my_output.permute(0, 2, 1)    # back to (N, C, L)
They give the same result.
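As a further sanity check (my own addition, not part of the original post), you can also test with a non-trivial mask by applying nn.InstanceNorm1d per sample to only the valid objects; the lengths below are made-up examples. Each comparison should print True:

import torch
import torch.nn as nn

N, L, C = 4, 10, 8
x = torch.randn(N, L, C)
lengths = torch.tensor([10, 7, 5, 3])                          # assumed number of valid objects per sample
mask = torch.arange(L).unsqueeze(0) < lengths.unsqueeze(1)     # (N, L) boolean mask

my_output = masked_instance_norm(x, mask)                      # (N, L, C)

m = nn.InstanceNorm1d(C)
for i, n in enumerate(lengths.tolist()):
    # nn.InstanceNorm1d expects (N, C, L), so slice the valid objects and permute
    ref = m(x[i:i + 1, :n].permute(0, 2, 1)).permute(0, 2, 1)  # (1, n, C)
    print(torch.allclose(my_output[i, :n], ref[0], atol=1e-5))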