The following code reproduces the problem.
import torch
from torch.utils.data import Dataset
import numpy as np

class TEST1(Dataset):
    def __init__(self):
        super().__init__()
        self.dataset = None
        self.data_len = 3

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        if self.dataset is None:
            self.dataset = np.random.rand(3, 4)
        return torch.tensor(idx)  # here I return the index

class TEST2(Dataset):
    def __init__(self):
        super().__init__()
        self.dataset = None
        self.data_len = 3

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        if self.dataset is None:
            self.dataset = np.random.rand(3, 4)
        return torch.tensor(self.dataset[idx])  # here I return the data

test1 = TEST1()
test2 = TEST2()

for i, x in enumerate(test1):
    if i > 5:
        break
    print("test1", i, x, "len = ", len(test1))

for i, x in enumerate(test2):
    if i > 5:
        break
    print("test2", i, x, "len = ", len(test2))
Running this gives the following output:
test1 0 tensor(0) len = 3
test1 1 tensor(1) len = 3
test1 2 tensor(2) len = 3
test1 3 tensor(3) len = 3
test1 4 tensor(4) len = 3
test1 5 tensor(5) len = 3
test2 0 tensor([0.9721, 0.4401, 0.8097, 0.5749], dtype=torch.float64) len = 3
test2 1 tensor([0.3238, 0.0371, 0.4474, 0.3992], dtype=torch.float64) len = 3
test2 2 tensor([0.1020, 0.2405, 0.4203, 0.1879], dtype=torch.float64) len = 3
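Note that test2 prints only three items even though the break would allow up to six. My guess is that indexing the 3-row NumPy array with idx = 3 raises an IndexError, which is what ends that loop. A minimal check of this assumption (arr is just an illustrative stand-in for self.dataset):

import numpy as np

arr = np.random.rand(3, 4)  # same shape as self.dataset in TEST2
arr[3]  # raises IndexError: index 3 is out of bounds for axis 0 with size 3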
However, if I do not use break, the loop over test1 never stops. Why?
PyTorch 1.4.0 and PyTorch 1.3.1 both have this problem.
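For what it is worth, the endless loop does not seem to be PyTorch-specific. My guess (unconfirmed) is that, because Dataset defines no __iter__, Python falls back to its legacy sequence iteration protocol, which calls __getitem__ with 0, 1, 2, ... until an IndexError is raised, and never consults __len__. A minimal sketch without PyTorch under that assumption (the class name Plain is made up for illustration):

class Plain:
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        return idx  # never raises IndexError, just like TEST1

for i, x in enumerate(Plain()):
    if i > 5:  # without this break, this loop also never stops
        break
    print("plain", i, x, "len = ", len(Plain()))

This prints six items, exactly like test1, even though len(Plain()) is 3.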