Hi,
I am new to PyTorch and I want to load my dataset using `TensorDataset()`:
```python
from torch.utils.data import TensorDataset

train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_val, y_val)
```
How can I add transforms to this? I have already tried a regular custom `Dataset` class and it worked as expected.
Thanks 
`TensorDataset` doesn't accept transforms; it just indexes the tensors passed to it.
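To illustrate what "just indexes the passed tensors" means, here is a small sketch with hypothetical toy tensors `X` and `y`: indexing the dataset returns a tuple holding the i-th slice of each tensor, with no processing applied.

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical toy data: 4 samples with 3 features each, plus integer labels
X = torch.arange(12, dtype=torch.float32).reshape(4, 3)
y = torch.tensor([0, 1, 0, 1])

ds = TensorDataset(X, y)

# ds[i] is simply (X[i], y[i]); the values come back unchanged
x1, y1 = ds[1]
print(x1)       # tensor([3., 4., 5.])
print(y1)       # tensor(1)
print(len(ds))  # 4
```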
If you want to apply transforms to these tensors, you could create a custom `Dataset`, e.g.:
```python
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Apply the (optional) transform to the sample
        x = self.data[index]
        if self.transform is not None:
            x = self.transform(x)

        # Apply the (optional) transform to the target
        y = self.target[index]
        if self.target_transform is not None:
            y = self.target_transform(y)

        return x, y
```
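If you would rather keep your existing `TensorDataset` objects, a variation on the same idea is to wrap them. The sketch below assumes a hypothetical `TransformedDataset` wrapper and a hypothetical `standardize` transform (neither is part of PyTorch); any callable that takes and returns a tensor would work in its place.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class TransformedDataset(Dataset):
    """Hypothetical wrapper: applies optional transforms on top of any (x, y) dataset."""

    def __init__(self, base, transform=None, target_transform=None):
        self.base = base
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        # Fetch the raw (x, y) pair from the wrapped dataset, then transform it
        x, y = self.base[index]
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y


# Random toy data standing in for X_train / y_train
X_train = torch.randn(8, 3)
y_train = torch.randint(0, 2, (8,))


def standardize(x):
    # Example transform (an assumption): zero mean, unit std per sample
    return (x - x.mean()) / (x.std() + 1e-8)


train_dataset = TransformedDataset(TensorDataset(X_train, y_train), transform=standardize)
loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

for xb, yb in loader:
    print(xb.shape, yb.shape)  # torch.Size([4, 3]) torch.Size([4])
```

Since the transform runs inside `__getitem__`, it is applied lazily each time a sample is loaded, so random augmentations would produce a fresh result every epoch.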
Very cool! I'm actually surprised by how flexible PyTorch is.