The batch that I’m expecting has three elements with the shapes: (64,1000), (64) and (64,100)
Instead, I’m getting all 80.000 elements in the training data ((80.000,1000), (80.000 ) and (80.000 ,100)). What am I doing wrong? On another note: is there a way of accessing the elements in the batch by attribute instead of key?
from torch.utils.data import DataLoader, TensorDataset
import torch
class dataset(TensorDataset):
def __init__(self):
self.X = None
self.y_head1 = None
self.y_head2 = None
self.process()
def __len__(self):
return (self.X.shape[0])
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
return self.X[idx], self.y_head1[idx], self.y_head2[idx]
def process(self):
self.X = torch.rand(100000, 1000)
self.y_head1 = torch.rand(100000)
self.y_head2 = torch.rand(100000, 100)
dataset = dataset()
num_samples = len(dataset)
batch_size = 64
num_val = num_samples // 10
val_dataset = dataset[:num_val]
test_dataset = dataset[num_val:2 * num_val]
train_dataset = dataset[2 * num_val:]
def collate_fn(batch):
return {
'X': batch[0],
'y_head1': batch[1],
'y_head2': batch[2]
}
data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=10, collate_fn=collate_fn)
batch = next(iter(data_loader))
print(batch['X'].shape)
print(batch['y_head1'].shape)
print(batch['y_head2'].shape)