# To understand reductions, see this and related links:
# https://towardsdatascience.com/understanding-dimensions-in-pytorch-6edf9972d3be
"""
Dimension reduction. It collapses/reduces a specific dimension by selecting an element from that dimension to be
reduced.
Consider x is 3D tensor. x.sum(1) converts x into a tensor that is 2D using an element from D1 elements in
the 1th dimension. Thus:
x.sum(1) = x[i,k] = op(x[i,:,k]) = op(x[i,0,k],...,x[i,D1,k])
the key is to realize that we need 3 indices to select a single element. So if we use only 2 (because we are collapsing)
then we have D1 number of elements possible left that those two indices might indicate. So from only 2 indices we get a
set that we need to specify how to select. This is where the op we are using is used for and selects from this set.
In theory if we want to collapse many indices we need to indicate how we are going to allow indexing from a smaller set
of indices (using the remaining set that we'd usually need).
"""
import torch
x = torch.tensor([
[1, 2, 3],
[4, 5, 6]
])
print(f'x.size() = {x.size()}')
# sum over dimension 0 (down the rows): each column's entries are added together -> [5, 7, 9]
x0 = x.sum(0)
print(x0)
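# A small sanity check (not in the original walkthrough): the reduced dimension
# disappears from the shape, (2, 3) -> (3,).
assert x0.size() == torch.Size([3])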
# sum over dimension 1 (across the columns) -> [6, 15]
x1 = x.sum(1)
print(x1)
# dim=-1 means the last dimension; for a 2D tensor this is the same as x.sum(1)
x_1 = x.sum(-1)
print(x_1)
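# A minimal sketch verifying the identity from the docstring on a 3D tensor
# (x3 and s are throwaway names for illustration): x3.sum(1)[i, k] equals
# x3[i, :, k].sum(), i.e. the op is applied to the D1 candidates left over
# when only 2 of the 3 indices are given.
x3 = torch.arange(24).reshape(2, 3, 4)
s = x3.sum(1)  # shape (2, 4): dimension 1 (of size 3) is collapsed
for i in range(2):
    for k in range(4):
        assert s[i, k] == x3[i, :, k].sum()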
# unlike sum, max returns a named tuple (values, indices)
x0 = x.max(0)
print(x0.values)
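# The indices field records where along the reduced dimension each max came
# from; for the x above, every column's max sits in row 1.
print(x0.indices)  # tensor([1, 1, 1])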
y = torch.tensor([
    [[ 1,  2,  3,  4],
     [ 5,  6,  7,  8],
     [ 9, 10, 11, 12]],
    [[13, 14, 15, 16],
     [17, 18, 19, 20],
     [21, 22, 23, 24]]
])
print(y)
# fix row 0 and column 0, vary dimension 0 ("into the screen"): [1, 13]
print(y[:, 0, 0])
# fix screen 0 and column 0, vary dimension 1 (down a column): [1, 5, 9]
print(y[0, :, 0])
# fix screen 0 and row 0, vary dimension 2 (across a row): [1, 2, 3, 4]
print(y[0, 0, :])
# for each remaining (row, column) index pair, select the largest value along the "screen" dimension (dim 0)
y0 = y.max(0)
print(y0.values)
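# The same identity as in the docstring, with max as the op (a sketch using
# the y defined above): y.max(0).values[j, k] == y[:, j, k].max().
for j in range(3):
    for k in range(4):
        assert y0.values[j, k] == y[:, j, k].max()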