Transforms for TensorDataset()

I am new to PyTorch and I want to load the dataset using TensorDataset():

train_dataset = TensorDataset(X_train, y_train)

test_dataset = TensorDataset(X_val, y_val)

How can I add transforms to this? I have already tried the regular custom Dataset class approach and it worked normally.

TensorDataset doesn’t accept transformations and will just index the passed tensors.
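To illustrate what "just index the passed tensors" means, here is a minimal sketch with hypothetical toy tensors (`X`, `y` are stand-ins, not the poster's data): `ds[i]` returns a tuple of each stored tensor sliced at `i` along dim 0, with no hook for a transform.

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical toy data: 4 samples with 3 features each, plus integer labels
X = torch.arange(12, dtype=torch.float32).reshape(4, 3)
y = torch.tensor([0, 1, 0, 1])

ds = TensorDataset(X, y)

# Indexing slices each stored tensor along dim 0 and returns them as a tuple
x0, y0 = ds[1]
print(x0)       # tensor([3., 4., 5.])
print(y0)       # tensor(1)
print(len(ds))  # 4
```

All tensors passed to TensorDataset must have the same size in their first dimension, since that dimension is what gets indexed.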
If you want to apply transformations to these tensors, you could create a custom Dataset, e.g. via:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data                          # e.g. X_train
        self.target = target                      # e.g. y_train
        self.transform = transform                # applied to each sample
        self.target_transform = target_transform  # applied to each target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        if self.transform:
            x = self.transform(x)
        y = self.target[index]
        if self.target_transform:
            y = self.target_transform(y)
        return x, y
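If you would rather keep the existing `TensorDataset(X_train, y_train)` lines, an alternative is to wrap any `(x, y)` dataset in a small transform-applying Dataset. This is a minimal sketch: `TransformedDataset` and `standardize` are hypothetical names, and the random `X_train`/`y_train` are stand-ins for your real tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

class TransformedDataset(Dataset):
    """Hypothetical wrapper: applies transforms on top of any (x, y) dataset."""
    def __init__(self, base, transform=None, target_transform=None):
        self.base = base
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        x, y = self.base[index]  # TensorDataset returns a tuple
        if self.transform is not None:
            x = self.transform(x)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y

torch.manual_seed(0)
# Hypothetical stand-ins for X_train / y_train
X_train = torch.randn(8, 3)
y_train = torch.randint(0, 2, (8,))

# Example transform (an assumption): standardize each sample to zero mean, unit std
def standardize(x):
    return (x - x.mean()) / (x.std() + 1e-8)

train_dataset = TransformedDataset(TensorDataset(X_train, y_train), transform=standardize)
loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([4, 3]) torch.Size([4])
```

The transform runs lazily in `__getitem__`, so each sample is transformed when it is loaded rather than once up front, which is what you want for random augmentations.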
