Efficient way to normalize a tensor by channel

Hello, I want to normalize an input tensor by the maximum value of each sample.
For example, if the input is

x = torch.rand(4,1,5,5)

I can do something like

max_values = x.view(4, -1).max(dim=1, keepdim=True)[0]

to get the max values of each sample.
Then I can get the normalized tensor by

norm_x = x / (max_values * torch.ones(4, 25)).view(4,1,5,5)

But this looks really ugly and inefficient. Could you please suggest a better way to do this?



There is a faster way: you can speed up your code with broadcasting, rather than allocating a new tensor of ones to match the size of each sample.

def slow_func(x):
    # number of elements per sample (product of all non-batch dims)
    numel_per_sample = x.size()[1:].numel()
    # per-sample max over the flattened sample; [0] keeps the values, drops the indices
    max_values = x.view(x.shape[0], -1).max(dim=-1, keepdim=True)[0]
    # materialize a full-size tensor of maxima, then divide
    norm_x = x / (max_values * torch.ones(x.shape[0], numel_per_sample)).view(x.shape)
    return norm_x

def fast_func(x):
    # per-sample max over the flattened sample, shape (N, 1)
    max_values = x.reshape(x.shape[0], -1).max(dim=-1, keepdim=True)[0]
    # reshape to (N, 1, 1, 1) so broadcasting handles the division directly
    norm_x = x / max_values.unsqueeze(2).unsqueeze(3)
    return norm_x
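For reference, here is a self-contained sanity check that the two versions produce the same result (the function bodies are copied from above):

```python
import torch

def slow_func(x):
    # allocate a full-size ones tensor matching each sample, then divide
    numel_per_sample = x.size()[1:].numel()
    max_values = x.view(x.shape[0], -1).max(dim=-1, keepdim=True)[0]
    return x / (max_values * torch.ones(x.shape[0], numel_per_sample)).view(x.shape)

def fast_func(x):
    # rely on broadcasting: a (N, 1, 1, 1) tensor divides (N, C, H, W) directly
    max_values = x.reshape(x.shape[0], -1).max(dim=-1, keepdim=True)[0]
    return x / max_values.unsqueeze(2).unsqueeze(3)

x = torch.rand(1000, 1, 5, 5)
print(torch.allclose(slow_func(x), fast_func(x)))  # prints: True
```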

I ran it for a larger batch size (1000) and here are the results I got:


%timeit slow_func(x) #37.2 µs ± 168 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit fast_func(x) #24.9 µs ± 113 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

I’m somewhat surprised torch.max doesn’t work over multiple dimensions
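(In more recent PyTorch releases, torch.amax does reduce over several dimensions at once, which lets you skip the flatten/view step entirely. A minimal sketch:)

```python
import torch

x = torch.rand(1000, 1, 5, 5)
# reduce over all non-batch dims in one call; keepdim=True yields shape (1000, 1, 1, 1),
# which broadcasts against x directly
max_values = torch.amax(x, dim=(1, 2, 3), keepdim=True)
norm_x = x / max_values
```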


The code looks promising and very clean. Thank you so much!