x[separation:]
slices an array (or, in this case, a tensor). separation is an index equal to 80% of the dataset's length, since we want to keep the first 80% of the data for training and the remaining 20% for testing.
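In practice, separation is usually computed from the dataset's length. A minimal sketch (the 65536-sample length comes from the indices mentioned below; the variable names are just for illustration):

```python
n_samples = 65536                  # dataset length, matching the indices below

# index marking the 80/20 boundary; int() truncates the fractional part
separation = int(0.8 * n_samples)

print(separation)              # 52428 samples for training
print(n_samples - separation)  # 13108 samples for testing
```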
x[:separation]
takes all the elements of the tensor up to, but not including, separation
(so indices 0, 1, 2, …, separation - 1)
x[separation:]
takes all the elements of the tensor from separation (inclusive) to the end of the tensor
(so separation, separation + 1, …, 65534, 65535)
Here is a small example with a 5 x 5
tensor:
import torch

x = torch.rand(size=(5, 5), dtype=torch.float32)
print(x)
# tensor([[0.0345, 0.3828, 0.2489, 0.4129, 0.4522],
# [0.0787, 0.7049, 0.2124, 0.2115, 0.1857],
# [0.6836, 0.7091, 0.2063, 0.1679, 0.3338],
# [0.7525, 0.1769, 0.1104, 0.0380, 0.6871],
# [0.9377, 0.6564, 0.2296, 0.5100, 0.7274]])
separation = 3
# first 3 rows
y = x[:separation]
print(y)
# tensor([[0.0345, 0.3828, 0.2489, 0.4129, 0.4522],
# [0.0787, 0.7049, 0.2124, 0.2115, 0.1857],
# [0.6836, 0.7091, 0.2063, 0.1679, 0.3338]])
# last 2 rows
z = x[separation:]
print(z)
# tensor([[0.7525, 0.1769, 0.1104, 0.0380, 0.6871],
# [0.9377, 0.6564, 0.2296, 0.5100, 0.7274]])
# rows at indices 1, 2, 3 (the stop index 4 is excluded)
d = x[1:4]
print(d)
# tensor([[0.0787, 0.7049, 0.2124, 0.2115, 0.1857],
# [0.6836, 0.7091, 0.2063, 0.1679, 0.3338],
# [0.7525, 0.1769, 0.1104, 0.0380, 0.6871]])
No specific reason, except that I forgot!