import torch
from sklearn.datasets import fetch_openml
from torch.utils.data import Dataset
# Load the mnist_784 dataset from https://www.openml.org/d/554
# as_frame=False asks fetch_openml for plain NumPy arrays instead of a pandas
# DataFrame/Series, so NumPy-style positional indexing such as X[index, :]
# keeps working further down.
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
print(X.shape)
class SimpleDataset(Dataset):
    """Map-style dataset that serves ``(inputs, targets)`` pairs as tensors.

    Args:
        X: 2-D feature container with one row per sample, e.g. a NumPy array
            or a pandas DataFrame.
        y: 1-D container of labels aligned with the rows of ``X``; every
            entry must be convertible with ``int()``.
    """

    def __init__(self, X, y):
        super().__init__()
        # pandas containers do label-based lookups, so ``X[index, :]`` raises
        # on a DataFrame. Keep the underlying arrays instead so that indexing
        # in __getitem__ is always positional.
        self.X = self._as_array(X)
        self.y = self._as_array(y)

    @staticmethod
    def _as_array(data):
        """Return ``data.to_numpy()`` for pandas-like input, else ``data`` as is."""
        to_numpy = getattr(data, "to_numpy", None)
        return to_numpy() if callable(to_numpy) else data

    def __getitem__(self, index):
        """Return the sample at position ``index`` as ``(inputs, targets)``.

        ``inputs`` is a float32 tensor built from row ``index`` of ``X``;
        ``targets`` is an int64 scalar tensor built from ``int(y[index])``.
        """
        # The tensor conversion is done lazily, one item at a time, rather
        # than once up front in the constructor.
        inputs = torch.tensor(self.X[index, :], dtype=torch.float32)
        targets = torch.tensor(int(self.y[index]), dtype=torch.int64)
        return inputs, targets

    def __len__(self):
        """Return the number of samples, i.e. the number of rows of ``X``."""
        return self.X.shape[0]
# Wrap the loaded features and labels in the PyTorch dataset defined above.
dataset = SimpleDataset(X, y)
print("Length: ", len(dataset))

# Indexing the dataset yields an (inputs, targets) pair; look at the first one.
first_sample = dataset[0]
example, label = first_sample
print("Features: ", example.shape)  # expected to report 784 features
print("Label of index 0: ", label)
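The screenshot shows an indexing error raised by pandas, so check where a `pandas.DataFrame` is used and how it is being indexed to narrow down the root cause of the invalid index. In this script, recent scikit-learn versions make `fetch_openml` return `X` as a DataFrame by default, and a DataFrame does not support NumPy-style positional indexing such as `X[index, :]`; either index it with `X.iloc[index]` or convert it to a NumPy array first (for example with `as_frame=False` or `X.to_numpy()`).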