Dataloader returns full-batch instead of batch size

The batch that I’m expecting has three elements with the shapes: (64,1000), (64) and (64,100)
Instead, I’m getting all 80.000 elements in the training data ((80.000,1000), (80.000 ) and (80.000 ,100)). What am I doing wrong? On another note: is there a way of accessing the elements in the batch by attribute instead of key?

from torch.utils.data import DataLoader, TensorDataset
import torch

class dataset(TensorDataset):

    def __init__(self):
        self.X = None
        self.y_head1 = None
        self.y_head2 = None
        self.process() 

    def __len__(self):
        return (self.X.shape[0])

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.X[idx], self.y_head1[idx], self.y_head2[idx]

    def process(self):
        self.X = torch.rand(100000, 1000)
        self.y_head1 = torch.rand(100000)
        self.y_head2 = torch.rand(100000, 100)

dataset = dataset() 

num_samples = len(dataset)

batch_size = 64

num_val = num_samples // 10

val_dataset = dataset[:num_val]  
test_dataset = dataset[num_val:2 * num_val] 
train_dataset = dataset[2 * num_val:] 

def collate_fn(batch):
    return {
        'X': batch[0],
        'y_head1': batch[1],
        'y_head2': batch[2]         
    }

data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=10, collate_fn=collate_fn)

batch = next(iter(data_loader))
print(batch['X'].shape)
print(batch['y_head1'].shape)
print(batch['y_head2'].shape)

I don’t think this does what you think it does, you could use Subset for this, but you’d need to enumerate the indices.

Best regards

Thomas

Thank you Tom,

For those with a similar problem: in addition to this the batch in the collate function need to be stacked.

val_dataset = Subset(dataset, [i for i in range(num_val)])  
test_dataset = Subset(dataset, [i + num_val for i in range(num_val)])  
train_dataset = Subset(dataset, [i + 2*num_val for i in range(num_val*8)])  

def collate_fn(batch):

	return {
	    'X': torch.stack([x[0] for x in batch], dim=0),
	    'y_head1': torch.stack([x[1] for x in batch], dim=0)    ,
	    'y_head2': torch.stack([x[2] for x in batch], dim=0)       
	}

data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=10, collate_fn=collate_fn)

batch = next(iter(data_loader))
print(batch['X'].shape)
print(batch['y_head1'].shape)
print(batch['y_head2'].shape)