Transforms for TensorDataset()

I am new to PyTorch :baby: and I want to load my dataset using TensorDataset():

```python
from torch.utils.data import TensorDataset

train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_val, y_val)
```
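As a quick illustration, here is a sketch with small made-up tensors standing in for `X_train` and `y_train` (their shapes and values are assumptions). `TensorDataset` returns the row at the requested index from each wrapped tensor, unchanged, and its constructor has no transform argument:

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical stand-ins for X_train / y_train: 4 samples with 3 features each
X_train = torch.arange(12, dtype=torch.float32).reshape(4, 3)
y_train = torch.tensor([0, 1, 0, 1])

train_dataset = TensorDataset(X_train, y_train)

# Indexing returns the matching row of every wrapped tensor, as-is
x, y = train_dataset[1]
print(len(train_dataset))  # 4
print(x)                   # tensor([3., 4., 5.])
print(y)                   # tensor(1)
```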

How can I add transforms to this? I have already tried the normal Dataset class, and transforms worked fine there.

Thanks :heart:

TensorDataset doesn’t accept transformations and will just index the passed tensors.
If you want to apply transformations on these tensors you could create a custom Dataset, e.g. via:

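```python
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Index the sample, then apply the optional input transform
        x = self.data[index]
        if self.transform:
            x = self.transform(x)
        # Index the label, then apply the optional target transform
        y = self.target[index]
        if self.target_transform:
            y = self.target_transform(y)
        return x, y
```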
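A minimal usage sketch follows. The data, the standardization transform and the label cast are all hypothetical choices for illustration; any callable that takes and returns a tensor works as `transform` or `target_transform`. The class is repeated so the snippet runs on its own:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    # Same class as above, repeated so this snippet is self-contained
    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        if self.transform:
            x = self.transform(x)
        y = self.target[index]
        if self.target_transform:
            y = self.target_transform(y)
        return x, y

torch.manual_seed(0)
# Hypothetical data: 8 samples with 3 features, integer class labels
X_train = torch.randn(8, 3)
y_train = torch.randint(0, 2, (8,))

# Hypothetical transforms: standardize features per column, cast labels to float
mean, std = X_train.mean(dim=0), X_train.std(dim=0)
train_dataset = MyDataset(
    X_train, y_train,
    transform=lambda x: (x - mean) / std,
    target_transform=lambda y: y.float(),
)

loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
for xb, yb in loader:
    print(xb.shape, yb.dtype)  # torch.Size([4, 3]) torch.float32
```

Transforms run lazily inside `__getitem__`, once per sample each time it is loaded, so random augmentations produce a fresh result every epoch. If you set `num_workers > 0` on a platform that starts workers with `spawn`, use named functions instead of lambdas, since lambdas cannot be pickled.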

Too cool! I am actually surprised by PyTorch's flexibility :heart_eyes: