Hello. I would like to use instance normalization (1d), but I cannot use nn.InstanceNorm1d because my objects are masked. For example, I have an input of shape
batch_size (N), num_objects (L), features (C), and each batch element has a different number of objects, so the number of objects is not fixed. Therefore, I have a boolean mask of shape
batch_size (N), num_objects (L) for that. Using nn.InstanceNorm1d would compute the mean and std across the padded objects as well, which I don't want, so I've implemented it from scratch. But it doesn't seem to give me the accuracy I am expecting (I know the expected accuracy because I am re-implementing a paper). It would be helpful if someone could check my code.
EDIT: I realized that the problem was that detaching the mean and var calculations is necessary. I've added the detach calls in the code below.
import torch

def masked_instance_norm(x, mask, eps=1e-5):
    """
    x    of shape: [batch_size (N), num_objects (L), features (C)]
    mask of shape: [batch_size (N), num_objects (L)]
    """
    mask = mask.float().unsqueeze(-1)                                 # (N, L, 1)
    mean = torch.sum(x * mask, 1) / torch.sum(mask, 1)                # (N, C), stats over valid objects only
    mean = mean.detach()
    var_term = ((x - mean.unsqueeze(1).expand_as(x)) * mask) ** 2     # (N, L, C)
    var = torch.sum(var_term, 1) / torch.sum(mask, 1)                 # (N, C)
    var = var.detach()
    mean_reshaped = mean.unsqueeze(1).expand_as(x)                    # (N, L, C)
    var_reshaped = var.unsqueeze(1).expand_as(x)                      # (N, L, C)
    ins_norm = (x - mean_reshaped) / torch.sqrt(var_reshaped + eps)   # (N, L, C)
    return ins_norm
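One thing worth noting: nn.InstanceNorm1d also supports learnable affine parameters when constructed with affine=True. If you need those as well, here is a minimal sketch of my own (not part of the original code) that wraps the function in a small module with a per-channel weight and bias:

import torch
import torch.nn as nn

class MaskedInstanceNorm1d(nn.Module):
    """Masked instance norm with optional per-channel affine parameters (sketch)."""
    def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(num_features))   # gamma, shape (C,)
            self.bias = nn.Parameter(torch.zeros(num_features))    # beta, shape (C,)

    def forward(self, x, mask):
        # x: (N, L, C), mask: (N, L)
        out = masked_instance_norm(x, mask, self.eps)
        if self.affine:
            out = out * self.weight + self.bias                    # broadcasts over (N, L, C)
        return out

With affine=False (the nn.InstanceNorm1d default, which is what the comparison below uses) this reduces to the plain function.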
Verifying against PyTorch's nn.InstanceNorm1d (without a mask):
import torch
import torch.nn as nn

# PyTorch nn.InstanceNorm1d
m = nn.InstanceNorm1d(100)
input = torch.randn(20, 100, 40)          # (N, C, L)
pytorch_output = m(input)

# Mine
input = input.permute(0, 2, 1)            # (N, L, C), the shape my function expects
mask = torch.ones(20, 40)                 # (N, L); all ones for comparison purposes
my_output = masked_instance_norm(input, mask)
my_output = my_output.permute(0, 2, 1)    # back to (N, C, L)
They give the same result.
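As a further sanity check (my own addition, not part of the original post), you can also test with a non-trivial mask by applying nn.InstanceNorm1d per sample to only the valid objects; the lengths below are made-up examples. Each comparison should print True:

import torch
import torch.nn as nn

N, L, C = 4, 10, 8
x = torch.randn(N, L, C)
lengths = torch.tensor([10, 7, 5, 3])                          # assumed number of valid objects per sample
mask = torch.arange(L).unsqueeze(0) < lengths.unsqueeze(1)     # (N, L) boolean mask

my_output = masked_instance_norm(x, mask)                      # (N, L, C)

m = nn.InstanceNorm1d(C)
for i, n in enumerate(lengths.tolist()):
    # nn.InstanceNorm1d expects (N, C, L), so slice the valid objects and permute
    ref = m(x[i:i + 1, :n].permute(0, 2, 1)).permute(0, 2, 1)  # (1, n, C)
    print(torch.allclose(my_output[i, :n], ref[0], atol=1e-5))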