kuba
February 21, 2022, 8:21pm
1
Hi,
I have a custom Dataset (implementing __init__, __getitem__, and __len__) based on data.Dataset, which I save with torch.save(dataset, 'dataset.pt'). The dataset has 10 samples with 5 images each. Before saving it, I verify it by enumerating all samples and checking all 5 images per sample with
for index, (image1, image2, image3, image4, image5) in enumerate(dataset):
Later, when I load the same dataset with torch.load('dataset.pt'), its length is 2. Moreover, when I enumerate it with the same line
for index, (image1, image2, image3, image4, image5) in enumerate(dataset):
it raises ValueError: too many values to unpack (expected 5).
What am I missing in my process?
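Roughly, the workflow looks like this (just a sketch; TrainingDataset and the path are placeholders for the actual code):

dataset = TrainingDataset('path/to/data')  # custom data.Dataset with __init__/__getitem__/__len__
print(len(dataset))  # 10 before saving
torch.save(dataset, 'dataset.pt')
dataset = torch.load('dataset.pt')
print(len(dataset))  # 2 after loading
for index, (image1, image2, image3, image4, image5) in enumerate(dataset):
    ...  # ValueError: too many values to unpack (expected 5)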
ptrblck
February 21, 2022, 8:55pm
2
Could you post your Dataset definition, please?
I’m not sure how PyTorch serializes a Dataset, since you would usually save the input arguments (e.g. the paths if you are lazily loading the data, or the data directly) instead of the Dataset object.
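For example, a minimal sketch of that approach (assuming the data is already available as tensors; the names here are placeholders):

# save the underlying tensors (or file paths) instead of the Dataset object
torch.save({'images': images, 'labels': labels}, 'data.pt')

# later: reload the data and rebuild the Dataset from it
checkpoint = torch.load('data.pt')
dataset = torch.utils.data.TensorDataset(checkpoint['images'], checkpoint['labels'])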
kuba
February 21, 2022, 10:57pm
4
It is defined like this: class TrainingDataset(data.TensorDataset). But I believe you are right about the serialization, as one of my friends suggested something similar.
ptrblck
February 21, 2022, 11:15pm
5
OK, that’s still strange, as I’m able to restore the TensorDataset:
x, y = torch.randn(10, 1), torch.randn(10, 1)
dataset = torch.utils.data.TensorDataset(x, y)
torch.save(dataset, 'dataset.pt')
dataset_loaded = torch.load('dataset.pt')

for (x1, y1), (x2, y2) in zip(dataset, dataset_loaded):
    print('input diff {}'.format((x1 - x2).abs().max()))
    print('target diff {}'.format((y1 - y2).abs().max()))

print((dataset.tensors[0] == dataset_loaded.tensors[0]).all())
print((dataset.tensors[1] == dataset_loaded.tensors[1]).all())
All checks pass and no mismatch is detected.
kuba
February 21, 2022, 11:31pm
6
So I’m working on someone else’s code, and this is the full Dataset definition:
class TrainingDataset(data.Dataset):
    def __init__(self, datapath):
        self.numpyimage = []
        self.numpylabel_LA = []
        self.numpylabel_LAdist = []
        self.numpyprob_normal = []
        self.numpyprob_scar = []
        self.numpyprob_background = []
        self.NumOfSubjects = 0
        self.datafile = glob.glob(datapath + '/*')
        for subjectid in range(len(self.datafile)):
            #if subjectid > 1:
            #    break
            imagename = self.datafile[subjectid] + '/enhanced.nii.gz'
            LAlabelname = self.datafile[subjectid] + '/atriumSegImgMO.nii.gz'
            LAscarMaplabelname = self.datafile[subjectid] + '/scarSegImgM_wall.nii.gz'
            print('loading training image: ' + imagename)
            numpyimage, numpylabel_LA, numpylabel_LAdist, numpyprob_normal, numpyprob_scar = LoadDataset_scar(imagename, LAlabelname, LAscarMaplabelname)
            self.numpyimage.extend(numpyimage)
            self.numpylabel_LA.extend(numpylabel_LA)
            self.numpylabel_LAdist.extend(numpylabel_LAdist)
            # self.numpyprob_background.extend(numpy2Dbackgroundprob)
            self.numpyprob_normal.extend(numpyprob_normal)
            self.numpyprob_scar.extend(numpyprob_scar)
            self.NumOfSubjects += 1

    def __getitem__(self, item):
        numpyimage = np.array([self.numpyimage[item]])
        numpylabel_LA = np.array([self.numpylabel_LA[item]])
        numpylabel_LA = (numpylabel_LA > 0) * 1
        numpylabel_LAdist = np.array([self.numpylabel_LAdist[item]])
        # numpyprob_background = np.array([self.numpyprob_background[item]])
        numpyprob_normal = np.array([self.numpyprob_normal[item]])
        numpyprob_scar = np.array([self.numpyprob_scar[item]])
        tensorimage = torch.from_numpy(numpyimage).float()
        tensorlabel_LA = torch.from_numpy(numpylabel_LA.astype(np.float32))
        tensorlabel_LAdist = torch.from_numpy(numpylabel_LAdist.astype(np.float32))
        # tensorprob_background = torch.from_numpy(numpyprob_background.astype(np.float32))
        tensorprob_normal = torch.from_numpy(numpyprob_normal.astype(np.float32))
        tensorprob_scar = torch.from_numpy(numpyprob_scar.astype(np.float32))
        return tensorimage, tensorlabel_LA, tensorlabel_LAdist, tensorprob_normal, tensorprob_scar

    def __len__(self):
        return self.NumOfSubjects
ptrblck
February 21, 2022, 11:39pm
7
I don’t have the data files, but using random numpy arrays still works:
class TrainingDataset(torch.utils.data.Dataset):
    def __init__(self):
        print('calling init')
        self.numpyimage = []
        self.numpylabel_LA = []
        self.numpylabel_LAdist = []
        self.numpyprob_normal = []
        self.numpyprob_scar = []
        self.numpyprob_background = []
        self.NumOfSubjects = 0
        for subjectid in range(10):
            # numpyimage, numpylabel_LA, numpylabel_LAdist, numpyprob_normal, numpyprob_scar = LoadDataset_scar(imagename, LAlabelname, LAscarMaplabelname)
            numpyimage, numpylabel_LA, numpylabel_LAdist, numpyprob_normal, numpyprob_scar = np.random.randn(1, 1), np.random.randn(1, 1), np.random.randn(1, 1), np.random.randn(1, 1), np.random.randn(1, 1)
            self.numpyimage.extend(numpyimage)
            self.numpylabel_LA.extend(numpylabel_LA)
            self.numpylabel_LAdist.extend(numpylabel_LAdist)
            # self.numpyprob_background.extend(numpy2Dbackgroundprob)
            self.numpyprob_normal.extend(numpyprob_normal)
            self.numpyprob_scar.extend(numpyprob_scar)
            self.NumOfSubjects += 1

    def __getitem__(self, item):
        numpyimage = np.array([self.numpyimage[item]])
        numpylabel_LA = np.array([self.numpylabel_LA[item]])
        numpylabel_LA = (numpylabel_LA > 0) * 1
        numpylabel_LAdist = np.array([self.numpylabel_LAdist[item]])
        # numpyprob_background = np.array([self.numpyprob_background[item]])
        numpyprob_normal = np.array([self.numpyprob_normal[item]])
        numpyprob_scar = np.array([self.numpyprob_scar[item]])
        tensorimage = torch.from_numpy(numpyimage).float()
        tensorlabel_LA = torch.from_numpy(numpylabel_LA.astype(np.float32))
        tensorlabel_LAdist = torch.from_numpy(numpylabel_LAdist.astype(np.float32))
        # tensorprob_background = torch.from_numpy(numpyprob_background.astype(np.float32))
        tensorprob_normal = torch.from_numpy(numpyprob_normal.astype(np.float32))
        tensorprob_scar = torch.from_numpy(numpyprob_scar.astype(np.float32))
        return tensorimage, tensorlabel_LA, tensorlabel_LAdist, tensorprob_normal, tensorprob_scar

    def __len__(self):
        return self.NumOfSubjects

dataset = TrainingDataset()
torch.save(dataset, 'dataset.pt')
dataset_loaded = torch.load('dataset.pt')

for x1, x2 in zip(dataset, dataset_loaded):
    for x1_, x2_ in zip(x1, x2):
        print('input diff {}'.format((x1_ - x2_).abs().max()))
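You could also compare the lengths directly, since that was the part that changed for you:

print(len(dataset), len(dataset_loaded))  # both print 10 here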