Indexing in a tensor

Hi there
I have been reading through the Deep Learning with PyTorch book, and I have a small question.
In this code we are calculating the mean and std of the pixels, per channel, in order to normalize them:

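```python
import torch

n_channels = batch.shape[1]  # batch is a float tensor of shape (N, C, H, W)
for c in range(n_channels):
    # batch[:, c] takes channel c from every image in the batch: shape (N, H, W)
    mean = torch.mean(batch[:, c])
    std = torch.std(batch[:, c])
    batch[:, c] = (batch[:, c] - mean) / std
```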

What does batch[:, c] mean, given that the channels are in the 2nd dimension? And if the channel dimension were in another position, such as the last dimension, how could I calculate the mean across the channels?

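batch[:, c] will index the current channel: the : keeps the whole batch dimension (dim0) and c selects the channel in dim1. torch.mean and torch.std will then calculate the mean and standard deviation of this channel, respectively. If the channels were in the last dimension, you would select them via batch[..., c] instead.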
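A minimal sketch of the indexing described above, assuming a float batch with layout (N, C, H, W); the random data and the sizes (4 images, 3 channels, 8x8 pixels) are hypothetical and only serve as an illustration. It also shows that the per-channel loop can be written without a loop by reducing over every dimension except the channel one, and how the indexing and reduction dims change for a channels-last layout (N, H, W, C).

```python
import torch

torch.manual_seed(0)
# Hypothetical batch: 4 RGB images of size 8x8, layout (N, C, H, W)
batch = torch.rand(4, 3, 8, 8)

# batch[:, c] keeps all samples (dim 0) and picks channel c (dim 1)
print(batch[:, 0].shape)  # torch.Size([4, 8, 8])

# Per-channel stats without a loop: reduce over every dim except dim 1
mean = batch.mean(dim=(0, 2, 3))
std = batch.std(dim=(0, 2, 3))
print(mean.shape)  # torch.Size([3])

# Same normalization as the loop; view(1, -1, 1, 1) lets the stats broadcast
normalized = (batch - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

# Channels-last layout (N, H, W, C): index the last dim, reduce over the rest
batch_cl = batch.permute(0, 2, 3, 1)
print(batch_cl[..., 0].shape)  # torch.Size([4, 8, 8])
mean_cl = batch_cl.mean(dim=(0, 1, 2))
```

Both torch.std calls use the default unbiased estimate over the same 256 values per channel, so the vectorized version should match the loop up to floating-point rounding.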
What should the result of the stats “across the channels” look like, i.e. which shape should it have?
If you want to calculate the mean in the channel dimension, you could use torch.mean(batch, dim=1).
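To make the shape question concrete, here is a small sketch (again with hypothetical random data of shape (4, 3, 8, 8)) of reducing over the channel dimension: torch.mean(batch, dim=1) collapses the channels and leaves one value per pixel of each image. When the channels sit in the last position, the same reduction uses dim=-1.

```python
import torch

# Hypothetical batch of 4 RGB images, 8x8 pixels, layout (N, C, H, W)
batch = torch.rand(4, 3, 8, 8)

# Average over the channel dimension: one value per pixel of each image
per_pixel = torch.mean(batch, dim=1)
print(per_pixel.shape)  # torch.Size([4, 8, 8])

# keepdim=True keeps a size-1 channel dim, handy for broadcasting back
per_pixel_kd = torch.mean(batch, dim=1, keepdim=True)
print(per_pixel_kd.shape)  # torch.Size([4, 1, 8, 8])

# Channels-last layout (N, H, W, C): the channel dim is now the last one
batch_cl = batch.permute(0, 2, 3, 1)
per_pixel_cl = torch.mean(batch_cl, dim=-1)
print(per_pixel_cl.shape)  # torch.Size([4, 8, 8])
```

This is a different quantity from the per-channel normalization stats in the original code, which have shape (C,) because they reduce over every dimension except the channels.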

@ptrblck thanks a lot, that was extremely helpful :smiley: