Hi,
My issue is similar to this solved one, but I am wondering why the DataLoader returns a list instead of a tuple.
Code for reproducing (PyTorch 2.0.1):
from typing import Any, Optional, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
class ImageDataset(Dataset):
    """Dataset pairing a tensor of images with a tensor of targets.

    Sample ``idx`` is the pair ``(imgs[idx], targets[idx])``, with the
    optional ``img_transform`` applied to the image only.
    """

    def __init__(
        self,
        imgs: torch.Tensor,
        targets: torch.Tensor,
        img_transform: Optional[Any] = None,
    ) -> None:
        """Store the tensors and the optional image transform.

        Args:
            imgs: Tensor of images, indexed along its first dimension.
            targets: Tensor of targets, indexed along its first dimension.
            img_transform: Optional callable applied to each image when it
                is fetched; ``None`` leaves images untouched.
        """
        super().__init__()
        self.imgs = imgs
        self.targets = targets
        self.img_transform = img_transform

    def __len__(self) -> int:
        """Return the number of samples.

        Raises:
            AssertionError: If ``imgs`` and ``targets`` differ in length.
        """
        # The check lives here (not in __init__) so it also catches
        # attributes that were reassigned after construction.
        assert len(self.imgs) == len(self.targets), "Number of images and targets must be equal"
        return len(self.imgs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(image, target)`` pair at position ``idx`` as a tuple."""
        img = self.imgs[idx]
        if self.img_transform is not None:
            img = self.img_transform(img)
        target = self.targets[idx]
        return img, target
if __name__ == "__main__":
    # Random stand-in data: 256 tensors of shape (3, 32, 32) and 256
    # integer targets drawn from [0, 10).
    imgs = torch.rand(256, 3, 32, 32)
    targets = torch.randint(0, 10, (256,))
    dataset = ImageDataset(imgs, targets)
    # Indexing the dataset directly returns exactly what __getitem__
    # returns; print the type so the script shows it when run (a bare
    # `dataset[0]` expression only echoes in an interactive session).
    print(type(dataset[0]))
The type of `dataset[0]` is `tuple`, just as I would expect, but putting the dataset into a DataLoader yields:
# Wrap the dataset in a DataLoader and inspect only the first batch.
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
for data in data_loader:
    print(data)
    break  # one batch is enough to see the container type
Why is the type of `data` a `list` now? I would still expect a `tuple`…
Best,
Imahn