Normalization to [0, 1] during training

I need to normalize each channel of a tensor of shape (B x C x W x H) to the range [0, 1] as part of the model, and I wrote this code:

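import torch

def normalize_channels(_x, every=True):
    # Output buffer with the same shape, dtype and device as the input
    out = torch.ones_like(_x)
    for i in range(_x.shape[0]):              # loop over the batch
        if every:
            # normalize each channel of sample i separately
            for j in range(_x[i].shape[0]):
                min_ = _x[i][j].min()
                max_ = _x[i][j].max()
                out[i][j].copy_((_x[i][j] - min_) * (1 / (max_ - min_)))
        else:
            # normalize all channels of sample i jointly
            min_ = _x[i].min()
            max_ = _x[i].max()
            out[i].copy_((_x[i] - min_) * (1 / (max_ - min_)))
    return out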

but the computation time increases too much. Does anyone have ideas on how to improve the performance and get rid of the explicit indexing?

Hello Maluma!

As a general rule, computations on tensors in PyTorch run faster if
you use built-in tensor operations, rather than looping over indices.

Try running torch.flatten() on the last two dimensions of your tensor
(start_dim=2) to get a tensor of shape (B x C x L), where L = W*H.
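A quick illustration of the flattening step, assuming a small random tensor with hypothetical sizes B=2, C=3, W=4, H=5:

```python
import torch

# Hypothetical example sizes: B=2 samples, C=3 channels, W=4, H=5.
x = torch.rand(2, 3, 4, 5)

# Merge the last two dimensions (W and H) into one of length L = W*H.
flat = torch.flatten(x, start_dim=2)
print(flat.shape)  # torch.Size([2, 3, 20])

# The values are unchanged; only the shape differs, so we can go back.
assert torch.equal(flat.view(2, 3, 4, 5), x)
```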

Then do your computations with things like torch.max(), using its
dim argument to specify that you want to take the max over your
tensor’s last dimension (W*H).
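A minimal sketch of those two steps put together, assuming a floating-point input of shape (B, C, W, H); the function name normalize_channels_sketch is hypothetical. With a dim argument, torch.min/torch.max return both values and indices, so only .values is kept, and keepdim=True leaves a size-1 dimension that broadcasts back against the flattened tensor:

```python
import torch

def normalize_channels_sketch(x: torch.Tensor) -> torch.Tensor:
    """Min-max normalize each (sample, channel) slice of a B x C x W x H tensor to [0, 1]."""
    flat = torch.flatten(x, start_dim=2)               # shape (B, C, W*H)
    # keepdim=True keeps shape (B, C, 1) so the result broadcasts against flat.
    min_ = flat.min(dim=-1, keepdim=True).values
    max_ = flat.max(dim=-1, keepdim=True).values
    out = (flat - min_) / (max_ - min_)
    return out.view_as(x)                              # back to (B, C, W, H)

x = torch.randn(2, 3, 4, 5)
y = normalize_channels_sketch(x)
print(y.shape)  # torch.Size([2, 3, 4, 5])
```

Note that this sketch does not guard against a constant channel (max == min), which would produce NaN.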

Give it a try, and if you have issues, post your code (working or
not) together with any errors.

(If you do get this working, it might be nice to post a follow-up with
before-and-after timings so we can see how much it helped.)
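One way to collect such before-and-after timings, as a minimal sketch: loop_version and vectorized_version are illustrative names (not code from this thread), the input size is hypothetical, and the synchronize calls only matter if the tensor is on a GPU, where kernels run asynchronously:

```python
import time
import torch

def loop_version(x):
    # Per-(sample, channel) min-max normalization with explicit Python loops.
    out = torch.empty_like(x)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            min_, max_ = x[i, j].min(), x[i, j].max()
            out[i, j] = (x[i, j] - min_) / (max_ - min_)
    return out

def vectorized_version(x):
    # Same computation using flatten + reductions over the last dimension.
    flat = torch.flatten(x, start_dim=2)
    min_ = flat.min(dim=-1, keepdim=True).values
    max_ = flat.max(dim=-1, keepdim=True).values
    return ((flat - min_) / (max_ - min_)).view_as(x)

def time_fn(fn, x, repeats=10):
    fn(x)  # warm-up run, not timed
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats  # mean seconds per call

x = torch.rand(8, 32, 32, 32)  # hypothetical batch
print(f"loop:       {time_fn(loop_version, x):.2e} s")
print(f"vectorized: {time_fn(vectorized_version, x):.2e} s")
```

PyTorch also ships torch.utils.benchmark, which handles warm-up and GPU synchronization for you.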

Good luck.

K. Frank

Thanks for the tip. I wrote this code:

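import torch

def normalize_channels(_x, inplace=False):
    # Collapse W and H so min/max can be taken per (sample, channel) in one call
    tmp = torch.flatten(_x, start_dim=2)
    _min = tmp.min(dim=-1)[0]        # shape (B, C); [0] keeps the values, drops the indices
    _max = tmp.max(dim=-1)[0]
    del tmp
    if inplace:
        # overwrite _x with (_x - min) * 1 / (max - min)
        return _x.add_(_min.unsqueeze(-1).unsqueeze(-1), alpha=-1).mul_(
            1 / _max.add(_min, alpha=-1).unsqueeze(-1).unsqueeze(-1))
    else:
        # same computation, but returns a new tensor and leaves _x unchanged
        return _x.add(_min.unsqueeze(-1).unsqueeze(-1), alpha=-1).mul(
            1 / _max.add(_min, alpha=-1).unsqueeze(-1).unsqueeze(-1))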
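One caveat when this runs inside a model during training: if a channel is constant (for example an all-zero feature map after a ReLU), max - min is zero and the result becomes NaN. A minimal sketch of a guarded variant, assuming an eps of 1e-8 is acceptable for your data; it also swaps flatten for torch.amin/torch.amax, which reduce over several dimensions (here W and H) at once. The name normalize_channels_eps is hypothetical:

```python
import torch

def normalize_channels_eps(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Reduce over W and H (dims 2 and 3) directly; keepdim=True gives shape (B, C, 1, 1).
    min_ = torch.amin(x, dim=(2, 3), keepdim=True)
    max_ = torch.amax(x, dim=(2, 3), keepdim=True)
    # Clamp the range so a constant channel (max == min) maps to 0 instead of NaN.
    return (x - min_) / (max_ - min_).clamp(min=eps)

raw = torch.rand(2, 3, 4, 4)
raw[0, 1] = 0.0                      # simulate a dead (all-zero) channel
x = raw.requires_grad_()
y = normalize_channels_eps(x)
y.sum().backward()                   # gradients flow, so it can sit inside a model
print(torch.isnan(y).any().item())  # False
```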

And it performs much better, about 29 times faster on average:

run    old (loop)   new (vectorized)   old - new
0      2.42E-01     7.80E-03           2.34E-01
1      2.38E-01     8.23E-03           2.30E-01
2      2.34E-01     8.17E-03           2.26E-01
3      2.26E-01     8.22E-03           2.18E-01
4      2.24E-01     8.21E-03           2.16E-01
5      2.54E-01     8.19E-03           2.46E-01
mean   2.36E-01     8.14E-03           2.28E-01