Python 3.6, torch 1.0.1.post2.
I’m training a simple network on FashionMNIST and getting a two-point accuracy difference (89% vs. 91%) between using my own Dataset instance and using the torchvision one. There should be no difference. I have a gist here. It can be run with `python3.6 runtest.py 1` to use the torchvision dataset, or `python3.6 runtest.py 2` to use the custom one.
The key differences: using the torchvision dataset,
```python
dset_train = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), download=True)
dset_test = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), train=False, download=True)
train_loader = torch.utils.data.DataLoader(dataset=dset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=dset_test, batch_size=256, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
run_training(train_loader, test_loader)
```
and using my own Dataset class,
```python
tmp_dset_train = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), download=True)
tmp_dset_test = torchvision.datasets.FashionMNIST(DATA, transform=torchvision.transforms.ToTensor(), train=False, download=True)
dset_train = FMNIST(tmp_dset_train.data, tmp_dset_train.targets)
dset_test = FMNIST(tmp_dset_test.data, tmp_dset_test.targets)
train_loader = torch.utils.data.DataLoader(dataset=dset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=dset_test, batch_size=256, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
run_training(train_loader, test_loader, True)  # True: inputs need an unsqueeze (channel dim) in the forward pass
```
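For context, the `FMNIST` class is just a thin wrapper around the pre-loaded tensors (a simplified sketch; the exact definition is in the gist):

```python
import torch
from torch.utils.data import Dataset

class FMNIST(Dataset):
    """Thin wrapper around pre-loaded image and label tensors (simplified sketch)."""
    def __init__(self, data, targets):
        self.data = data        # image tensor of shape (N, 28, 28)
        self.targets = targets  # label tensor of shape (N,)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # return the raw image tensor and its label, no transform applied
        return self.data[idx], self.targets[idx]
```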
See the gist for the FMNIST definition. The point is that essentially nothing happens in that class, so there should be no difference in performance, and yet there is one (quite a big one, in fact!). I suspect it has something to do with reshaping, but I’m not sure how that could go wrong.
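One way to narrow it down might be to compare what the two datasets actually yield per sample. Here is a hypothetical helper (not in the gist) that reports the dtype, shape, and value range of a sample from any dataset:

```python
import torch

def describe_sample(dataset, idx=0):
    """Return (dtype, shape, min, max) of one sample from a map-style dataset.

    Hypothetical diagnostic helper for comparing the torchvision pipeline
    against the custom FMNIST wrapper; any dtype or range mismatch between
    the two would explain a training difference.
    """
    x, _ = dataset[idx]
    x = torch.as_tensor(x)
    return x.dtype, tuple(x.shape), float(x.min()), float(x.max())
```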
What is going on?