Transforms for TensorDataset()

Hi,
I am new to PyTorch :baby: and I want to load my dataset using TensorDataset():

train_dataset = TensorDataset(X_train, y_train)

test_dataset = TensorDataset(X_val, y_val)

How can I add transforms to it? I have already tried a regular custom Dataset class, and transforms worked as expected there.

Thanks :heart:

TensorDataset doesn’t accept transformations. It just indexes each of the passed tensors and returns the results as a tuple.
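A minimal sketch of that indexing behavior, using small made-up tensors in place of X_train/y_train (the data here is purely illustrative):

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical toy data: 4 samples with 3 features each, plus integer labels
X_train = torch.arange(12, dtype=torch.float32).reshape(4, 3)
y_train = torch.tensor([0, 1, 0, 1])

# All tensors must share the same size in dimension 0
train_dataset = TensorDataset(X_train, y_train)

# Indexing returns a tuple holding the i-th slice of every tensor, unmodified
x0, y0 = train_dataset[0]
print(len(train_dataset))  # 4
print(x0)  # tensor([0., 1., 2.])
print(y0)  # tensor(0)
```

There is no hook between the indexing and the return, which is why a transform has nowhere to go.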
If you want to apply transformations to these tensors, you could create a custom Dataset, e.g. via:

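from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data                          # input tensor, indexed along dim 0
        self.target = target                      # target tensor, same length as data
        self.transform = transform                # optional callable applied to each input
        self.target_transform = target_transform  # optional callable applied to each target
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]
        if self.transform is not None:
            x = self.transform(x)
        
        y = self.target[index]
        if self.target_transform is not None:
            y = self.target_transform(y)
            
        return x, y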
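A possible usage sketch: the standardization and one-hot transforms below are hypothetical examples chosen for illustration, and the class is repeated in condensed form so the snippet runs on its own. Any callable that takes a single sample and returns the transformed sample can be passed as `transform` or `target_transform`.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    # Condensed copy of the custom Dataset above so this snippet is self-contained
    def __init__(self, data, target, transform=None, target_transform=None):
        self.data, self.target = data, target
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x, y = self.data[index], self.target[index]
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y

# Hypothetical data, for illustration only
X_train = torch.arange(12, dtype=torch.float32).reshape(4, 3)
y_train = torch.tensor([0, 1, 0, 1])
mean, std = X_train.mean(dim=0), X_train.std(dim=0)

def standardize(x):
    # Per-feature standardization using statistics computed on the training set
    return (x - mean) / std

def to_one_hot(y):
    # Convert an integer class label into a one-hot float vector (2 classes assumed)
    return F.one_hot(y, num_classes=2).float()

train_dataset = MyDataset(X_train, y_train, transform=standardize, target_transform=to_one_hot)
loader = DataLoader(train_dataset, batch_size=2, shuffle=False)

for xb, yb in loader:
    print(xb.shape, yb.shape)  # torch.Size([2, 3]) torch.Size([2, 2])
```

The transforms run per sample inside `__getitem__`, so the DataLoader's default collation simply stacks the already-transformed samples into batches.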

Very cool! I’m actually surprised by PyTorch’s flexibility :heart_eyes: