You could add the normalization in the `__getitem__` method of your `Dataset`:
class MyDataset(Dataset):
    """Map-style dataset over paired samples and targets.

    Normalization (or any other preprocessing) is applied lazily, one sample
    at a time, inside ``__getitem__`` through the optional ``transform``.

    Args:
        X: Indexable collection of samples; ``len(X)`` is the dataset length.
        y: Indexable collection of targets, paired with ``X`` by index.
        transform: Optional callable applied to each sample (not to the
            target) before it is returned, e.g. a ``transforms.Compose``
            pipeline that includes ``transforms.Normalize``.
    """

    def __init__(self, X, y, transform=None):
        self.data = X
        self.target = y
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(sample, target)`` pair at ``index``.

        If a ``transform`` was given, it is applied to the sample only; the
        target is returned unchanged.
        """
        x = self.data[index]
        y = self.target[index]
        # Normalization happens here, per sample, via the transform.
        # Compare against None explicitly: a callable transform object that
        # defines __len__ or __bool__ could otherwise evaluate falsy and be
        # skipped silently.
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        """Return the number of samples, i.e. ``len(X)``."""
        return len(self.data)
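In this case, you could set `transform` to a pipeline like the one below. Note that `ToTensor` expects a PIL image or a NumPy array (H×W×C), and the `mean`/`std` values are the per-channel ImageNet statistics, so substitute your own dataset's statistics if they differ: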
# Per-sample preprocessing pipeline; pass it as MyDataset(X, y, transform=transform).
# ToTensor accepts a PIL image or an H x W x C numpy array (uint8 values are
# scaled to [0, 1]) and returns a C x H x W float tensor. It rejects inputs that
# are already tensors; for such data, drop it and keep only Normalize.
# The mean/std values are the ImageNet per-channel (RGB) statistics used by
# torchvision's pretrained models. They assume 3-channel input; for other data,
# compute the statistics of your own training set.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])