TensorDataset in PyTorch

Hi all,

I'm a beginner and just started learning PyTorch. I saw TensorDataset being used in a Kaggle kernel, but I couldn't understand what exactly it does.

The description says it works by "indexing tensors along dimensions", but I have no clue what this means. Can anyone please explain?

Thanks in advance!

If you are a beginner in PyTorch, I would recommend reading this blog post about tensor indexing.
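To make "indexing along a dimension" a bit more concrete, here is a small sketch (the tensor `t` is just an illustrative example, not from the thread). Indexing along dimension 0 selects a row, and indexing along dimension 1 selects a column. TensorDataset indexes along dimension 0, so each row is treated as one sample.

```python
import torch

# A 3 x 4 tensor: 3 rows (dimension 0) and 4 columns (dimension 1)
t = torch.arange(12).reshape(3, 4)

# Indexing along dimension 0 picks out a whole row
row = t[0]      # tensor([0, 1, 2, 3])

# Indexing along dimension 1 picks out a whole column
col = t[:, 0]   # tensor([0, 4, 8])

print(row)
print(col)
```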

You can think of TensorDataset as bundling all your tensors together so that you can index them collectively:

import torch

x = torch.ones(10, 10)   # 10 samples with 10 features each
y = torch.zeros(10, 5)   # 10 targets with 5 values each
data = torch.utils.data.TensorDataset(x, y)
data[0]
# (tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
#  tensor([0., 0., 0., 0., 0.]))
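Roughly speaking, TensorDataset behaves like the simplified class below. This is a sketch for illustration, not PyTorch's actual source; `SimpleTensorDataset` is a hypothetical name. It checks that all tensors share the same size in dimension 0, keeps them as they are, and on `data[i]` indexes each one along dimension 0 and returns the results as a tuple.

```python
import torch
from torch.utils.data import Dataset


class SimpleTensorDataset(Dataset):
    """Hypothetical, simplified stand-in for TensorDataset (illustration only)."""

    def __init__(self, *tensors):
        # Every tensor must have the same number of samples along dimension 0
        assert all(t.size(0) == tensors[0].size(0) for t in tensors), "Size mismatch between tensors"
        self.tensors = tensors  # the original tensors are stored as-is, not merged

    def __getitem__(self, index):
        # Index every tensor along dimension 0 and return the results as a tuple
        return tuple(t[index] for t in self.tensors)

    def __len__(self):
        # Number of samples = size of dimension 0
        return self.tensors[0].size(0)


x = torch.ones(10, 10)
y = torch.zeros(10, 5)
data = SimpleTensorDataset(x, y)

sample_x, sample_y = data[0]
print(sample_x.shape, sample_y.shape)  # torch.Size([10]) torch.Size([5])
print(len(data))                       # 10
```

Because `__getitem__` just forwards the index to each tensor, whatever you could do with `x[i]` on its own is done to all tensors at once.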

So, do you mean TensorDataset merges the different tensors?

Sort of. But every tensor must have the same size in dimension 0. Also, it does not actually merge the tensors; it just lets you use data[0] to index all of them at the same time. That is why each access returns a tuple, with one element per tensor.
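A quick way to check both points is sketched below, using the same shapes as the example above. In the PyTorch versions I'm familiar with, the dataset keeps references to the original tensors in its `tensors` attribute, and the size check in the constructor is an `assert`, so a mismatch in dimension 0 raises `AssertionError`.

```python
import torch
from torch.utils.data import TensorDataset

x = torch.ones(10, 10)
y = torch.zeros(10, 5)
data = TensorDataset(x, y)

# No merging: the dataset holds references to the original tensors
print(data.tensors[0] is x, data.tensors[1] is y)  # True True

# Each access returns a tuple with one entry per tensor
sample = data[0]
print(type(sample), len(sample))  # <class 'tuple'> 2

# Tensors whose dimension 0 differs (10 vs 8 here) are rejected
rejected = False
try:
    TensorDataset(torch.ones(10, 10), torch.zeros(8, 5))
except AssertionError as err:
    rejected = True
    print("Rejected:", err)
```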

In the example above, you can see that data[0] returns two tensors as a tuple. The first tensor is x[0] and the second tensor is y[0].
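With all-ones and all-zeros tensors it is hard to see which row you got back, so here is a sketch with distinct values per row (the tensors are made up for illustration). It checks that data[i] is (x[i], y[i]) for every i. Since each tensor is indexed with the same object, a slice like data[2:5] also returns a tuple of slices.

```python
import torch
from torch.utils.data import TensorDataset

# Distinct values per row so the correspondence is visible
x = torch.arange(10).reshape(10, 1).float()  # x[i] == [i]
y = torch.arange(10) * 100                   # y[i] == i * 100
data = TensorDataset(x, y)

xi, yi = data[3]
print(xi, yi)  # tensor([3.]) tensor(300)

# data[i] is (x[i], y[i]) for every i
same = all(
    torch.equal(data[i][0], x[i]) and torch.equal(data[i][1], y[i])
    for i in range(len(data))
)
print(same)  # True

# A slice is applied to every tensor as well
xs, ys = data[2:5]
print(xs.shape, ys.tolist())  # torch.Size([3, 1]) [200, 300, 400]
```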