Hi,
I’m currently training a basic CNN on the CIFAR10 dataset, which I loaded with `train_dataset = torchvision.datasets.CIFAR10(DATA_PATH, train=True, transform=transform, download=True)`, and it reaches decent accuracy. However, I realized I had been tuning the hyperparameters against the test set and checking the results, which risks overfitting to the test set. So I looked into using scikit-learn's `GridSearchCV` class to find the best combination of hyperparameters through k-fold cross-validation. Code below:
In the process, I couldn't find a way to split `train_dataset` into X (images) and y (labels) for `gs.fit()`, which threw this error: `TypeError: fit() missing 1 required positional argument: 'y'`. Does anyone know a way to split the dataset back into images and labels?
Take a look at `torchvision.datasets.CIFAR10`. As you can see, its `__getitem__` returns a tuple `(image, target)`, where `target` is the class label of the image. You can iterate through the dataset like this:
```python
...
dataset = torchvision.datasets.CIFAR10(DATA_PATH, train=True, transform=transform, download=True)
for i, data in enumerate(dataset):  # i is the sample index
    image, label = data
    ...
```
Now you have your image as well as the corresponding label and you can work with it.
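A minimal sketch of that `(image, target)` protocol, using a hypothetical `ToyImageDataset` as a stand-in for `torchvision.datasets.CIFAR10` so it runs without downloading anything. The class name, the 3 x 4 x 4 nested-list "images" and the sample count are assumptions for illustration. Besides the loop, unpacking every pair with `zip(*...)` gives one sequence of images and one of labels:

```python
import random


class ToyImageDataset:
    """Hypothetical stand-in for a map-style dataset such as CIFAR10.

    Like CIFAR10, __getitem__ returns an (image, target) tuple. Here an
    "image" is a nested list of shape 3 x 4 x 4 filled with random floats.
    """

    def __init__(self, n_samples=10, n_classes=10, seed=0):
        rng = random.Random(seed)
        self.samples = [
            (
                [[[rng.random() for _ in range(4)] for _ in range(4)] for _ in range(3)],
                rng.randrange(n_classes),
            )
            for _ in range(n_samples)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Raises IndexError past the end, which also stops plain iteration.
        return self.samples[index]


dataset = ToyImageDataset(n_samples=10)

# Same loop as above: each item unpacks into (image, label).
labels_seen = []
for i, (image, label) in enumerate(dataset):  # i is the sample index
    labels_seen.append(label)

# Alternatively, collect all pairs and transpose them into two lists.
pairs = [dataset[i] for i in range(len(dataset))]
images, labels = map(list, zip(*pairs))
```

Both approaches visit every sample once; the `zip(*pairs)` form is convenient when you want the images and labels as two separate sequences of equal length.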
Hi boto,
Sorry for the late reply. I tried to split the dataset into x and y with:
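```python
import torch

def split_XY(dataset):
    labels = []
    images = torch.empty(len(dataset), 3, 32, 32)  # preallocate N x 3 x 32 x 32
    for i, (image, label) in enumerate(dataset):   # use the argument, not the global train_dataset
        images[i] = image
        labels.append(label)
    return images, torch.tensor(labels, dtype=torch.long)  # labels as int64 class indices
```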
and then called:
```python
x, y = split_XY(train_dataset)
gs.fit(x, y)
```
But when I ran it, it threw `ValueError: Cannot perform a CV split if dataset and y have different lengths.`, which didn’t make sense to me because x is 50000 x 3 x 32 x 32 and y is 50000 x 1, so both hold 50000 samples. Am I missing something? Thanks!
Unfortunately, I don’t have any experience with GridSearchCV.
But the error states that x and y do not have the same length. Their shapes certainly differ: x.shape=(50000, 3, 32, 32) and y.shape=(50000, 1). That’s all I can tell you. Maybe try looking at the skorch FAQ.
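For reference, scikit-learn style splitters compare `len(X)` with `len(y)`, which is only the size of the first dimension, so an image array of shape (N, 3, 32, 32) and a label array of shape (N,) count as the same length. Below is a sketch, an assumption rather than anything from the thread, of a `split_XY` variant that builds NumPy arrays from any map-style dataset returning `(image, label)` pairs. float32 images and int64 labels are chosen because that is what PyTorch classification losses typically expect; `FakeCIFAR` is a hypothetical stand-in so the example runs offline.

```python
import numpy as np


class FakeCIFAR:
    """Hypothetical stand-in that returns (image, label) like CIFAR10 with ToTensor."""

    def __init__(self, n_samples=8, seed=0):
        rng = np.random.default_rng(seed)
        self.images = rng.random((n_samples, 3, 32, 32), dtype=np.float32)
        self.labels = rng.integers(0, 10, size=n_samples)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index], int(self.labels[index])


def split_XY(dataset):
    """Stack every (image, label) pair into X of shape (N, C, H, W) and y of shape (N,)."""
    images, labels = [], []
    for i in range(len(dataset)):
        image, label = dataset[i]
        # np.asarray also accepts CPU torch tensors, e.g. the output of ToTensor().
        images.append(np.asarray(image, dtype=np.float32))
        labels.append(label)
    X = np.stack(images)                    # (N, 3, 32, 32)
    y = np.asarray(labels, dtype=np.int64)  # (N,), one class index per sample
    return X, y


X, y = split_XY(FakeCIFAR(n_samples=8))
print(X.shape, y.shape)  # (8, 3, 32, 32) (8,)
assert len(X) == len(y)  # len() is the size of the first dimension
```

If materializing all 50000 images in memory is undesirable, the skorch FAQ also describes `skorch.helper.SliceDataset` for passing a dataset's inputs and targets to `GridSearchCV` directly, which may be worth a look.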