Hi all,
I’m a beginner and just started learning PyTorch. I saw TensorDataset being used in a Kaggle kernel, but I couldn’t understand what exactly it does.
The description states: "indexing tensors along dimensions".
But I have no clue what this means. Can anyone please explain?
Thanks in advance!
If you are a beginner in PyTorch, I would recommend reading this blog post about tensor indexing.
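As a quick sketch of what "indexing along a dimension" means (the tensor below is just an illustrative example): indexing a 2-D tensor with a single integer selects along dimension 0 (a row), while `x[:, i]` selects along dimension 1 (a column).

```python
import torch

# A 3x4 tensor: dimension 0 has 3 rows, dimension 1 has 4 columns
x = torch.arange(12).reshape(3, 4)

row = x[1]      # index along dimension 0: the second row
col = x[:, 1]   # index along dimension 1: the second column

print(row)  # tensor([4, 5, 6, 7])
print(col)  # tensor([1, 5, 9])
```

TensorDataset always uses the first kind of indexing (dimension 0), so each row is treated as one sample.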
You can think of TensorDataset as combining all your tensors into a single object that you can index collectively. For example, `data[0]` returns element 0 of every tensor:
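```python
import torch

x = torch.ones(10, 10)   # 10 samples with 10 features each
y = torch.zeros(10, 5)   # 10 targets with 5 values each
data = torch.utils.data.TensorDataset(x, y)

print(data[0])  # element 0 of x and element 0 of y, as a tuple
# (tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
#  tensor([0., 0., 0., 0., 0.]))
```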
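To make the behaviour concrete, here is a rough, simplified sketch of what a TensorDataset-like class does. `SimpleTensorDataset` is a hypothetical name used only for illustration; it is not the library's actual source, just a class that behaves the same way for these simple cases.

```python
import torch
from torch.utils.data import Dataset, TensorDataset

class SimpleTensorDataset(Dataset):
    """Hypothetical, simplified stand-in for TensorDataset (illustration only)."""

    def __init__(self, *tensors):
        # Every tensor must have the same number of samples (size of dimension 0)
        assert all(t.size(0) == tensors[0].size(0) for t in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # Apply the same index to every tensor along dimension 0
        return tuple(t[index] for t in self.tensors)

    def __len__(self):
        # Number of samples = size of dimension 0
        return self.tensors[0].size(0)

x = torch.ones(10, 10)
y = torch.zeros(10, 5)
mine = SimpleTensorDataset(x, y)
ref = TensorDataset(x, y)

print(len(mine))  # 10
# Both return matching tuples for the same index
print(all(torch.equal(a, b) for a, b in zip(mine[3], ref[3])))  # True
```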
So, do you mean TensorDataset merges the different tensors together?
Sort of. But every tensor must have the same size in dimension 0. And it does not actually merge the tensors; it just lets you use `data` to index all of them at the same time. That is why indexing `data` still gives you a tuple, with one element per tensor.
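A small sketch of both points, assuming a recent PyTorch version where the size check is implemented as an `assert` (so a mismatch raises `AssertionError`): tensors with different sizes in dimension 0 are rejected, and the original tensors are kept separately in `data.tensors` rather than being concatenated.

```python
import torch
from torch.utils.data import TensorDataset

x = torch.ones(10, 10)
y = torch.zeros(10, 5)
bad_y = torch.zeros(8, 5)  # only 8 rows, while x has 10

data = TensorDataset(x, y)  # OK: both have size 10 in dimension 0

try:
    TensorDataset(x, bad_y)
    mismatch_rejected = False
except AssertionError:
    # Assumption: the size(0) check is an assert in this PyTorch version
    mismatch_rejected = True

print(mismatch_rejected)  # True

# The tensors are stored side by side, not merged into one tensor
print(data.tensors[0] is x, data.tensors[1] is y)  # True True
```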
In the example above, you can see that `data[0]` returns two tensors as a tuple: the first comes from `x` and the second from `y`.
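Because the result is a plain tuple, you can unpack it directly. The sketch below uses illustrative tensors with distinct values so you can see which row comes back; it also shows that a slice works the same way, indexing every tensor along dimension 0.

```python
import torch
from torch.utils.data import TensorDataset

x = torch.arange(20.).reshape(10, 2)  # 10 samples, 2 features each
y = torch.arange(10)                  # 10 labels
data = TensorDataset(x, y)

xi, yi = data[3]        # unpack the tuple: first from x, second from y
print(xi, yi)           # tensor([6., 7.]) tensor(3)

xs, ys = data[2:5]      # a slice is applied to every tensor along dimension 0
print(xs.shape, ys.shape)  # torch.Size([3, 2]) torch.Size([3])
```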