`TensorDataset` doesn’t accept transformations and will just index the passed tensors.
If you want to apply transformations to these tensors, you could create a custom `Dataset`, e.g. via:
class MyDataset(Dataset):
    """Dataset that indexes paired ``data``/``target`` containers and applies optional transforms.

    Unlike ``TensorDataset``, this lets a callable be applied to each sample
    and/or each target at access time.

    Args:
        data: Container of samples; must support ``len()`` and integer indexing
            (e.g. a tensor or a list).
        target: Container of targets, indexed with the same index as ``data``.
            Assumed to have the same length as ``data`` — TODO confirm with callers.
        transform: Optional callable applied to each sample. ``None`` disables it.
        target_transform: Optional callable applied to each target. ``None`` disables it.
    """

    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples, taken from ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(sample, target)`` pair at ``index``, transformed if configured."""
        x = self.data[index]
        # Compare against None rather than relying on truthiness: a callable
        # object that defines __bool__/__len__ could otherwise be skipped.
        if self.transform is not None:
            x = self.transform(x)
        y = self.target[index]
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y