Normalization on width dimension only


I want to use a custom version of instance normalization that normalizes only along the W dimension of activation maps of shape BxCxHxW.

InstanceNorm2d is not suitable for this without reshaping the activation maps, because it normalizes over both H and W.

Any suggestions?

# Plain InstanceNorm2d: normalizes over both H and W for each (sample, channel)
def forward(self, x):
    out = self.InstanceNorm2d(x)
    return out

# Workaround: fold H into the channel dimension so each row of length W is normalized on its own
def forward(self, x):
    b, c, h, w = x.shape
    out = torch.reshape(x, (b, c * h, w, 1))
    out = self.InstanceNorm2d(out)
    out = torch.reshape(out, (b, c, h, w))
    return out

Adding the two reshape operations makes training twice as slow.

Could you try writing the normalization manually and letting nvFuser generate the fused kernel for you, similar to this use case?
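
A minimal sketch of what "manually writing the normalization" could look like, assuming no affine parameters and no running statistics. The class name WidthNorm and the eps default of 1e-5 are my own choices for illustration, not from the thread. It computes the mean and the biased variance over the last (W) dimension directly, so no reshape is needed. Whether nvFuser or another fuser actually fuses these elementwise ops and reductions depends on your PyTorch version and setup, which is not shown here.

```python
import torch
import torch.nn as nn


class WidthNorm(nn.Module):
    """Normalize a BxCxHxW tensor along W only (hypothetical helper, no affine params)."""

    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps  # assumed default, chosen to match InstanceNorm2d's default eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics are taken per (b, c, h) row, i.e. over the last dimension only.
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        # Biased variance (divide by W), the same estimator instance norm uses.
        var = centered.pow(2).mean(dim=-1, keepdim=True)
        return centered * torch.rsqrt(var + self.eps)


if __name__ == "__main__":
    x = torch.randn(8, 16, 32, 64)
    y = WidthNorm()(x)
    print(y.shape)  # same shape as the input
```

Under these assumptions the result should match the reshape + InstanceNorm2d version up to floating-point error, so you can check one against the other before timing them.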