When I use the prefetch function, the code raises a runtime error.
The code looks like this:
    class data_prefetch():
        def __init__(self, cfg, loader, is_train):
            self.loader = loader

        def preload(self):
            try:
                self.next_meta = next(self.loader)
            except:
                self.next_meta = None
                return
The function runs, but it always ends up in the except branch.
Or like this:
    class data_prefetch():
        def __init__(self, cfg, loader, is_train):
            self.loader = iter(loader)

        def preload(self):
            try:
                self.next_meta = next(self.loader)
            except:
                self.next_meta = None
When I use the second class, the object raises a runtime error.
Why does this happen, and how can I solve this problem?
    import torchvision.transforms as transforms

    class data_prefetch():
        def __init__(self, cfg, loader, is_train):
            self.loader = loader
            self.is_train = is_train
            self.preload()

        def preload(self):
            try:
                self.next_meta = next(self.loader)
                print("try is okey")
            except:
                self.next_meta = None

        def next(self):
            meta = self.next_meta
            self.preload()
            return meta

        def __len__(self):
            return len(self.loader)
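One likely reason the first version (and the full class above) always lands in the except branch: a DataLoader is an iterable, not an iterator. It defines __iter__ but not __next__, so calling next() on it directly raises a TypeError, which the bare except: silently swallows. The sketch below shows the difference, using a plain Python list as a stand-in for the DataLoader (an assumption made only to keep the example dependency-free; the list behaves the same way with respect to next()). The helper try_next is hypothetical and exists only for this demonstration.

```python
# A plain list stands in for a DataLoader here: both are iterables
# (they define __iter__) but not iterators (they lack __next__).
loader = [10, 20, 30]


def try_next(obj):
    """Hypothetical helper: return ('ok', value) or ('error', exception type name)."""
    try:
        return ("ok", next(obj))
    except Exception as exc:  # report the exception type instead of hiding it
        return ("error", type(exc).__name__)


# Calling next() on the iterable itself fails with TypeError.
print(try_next(loader))  # ('error', 'TypeError')

# Wrapping it with iter() first gives an iterator that next() accepts.
it = iter(loader)
print(try_next(it))  # ('ok', 10)
```

Catching a specific exception (for example except StopIteration:) instead of a bare except: would have surfaced this TypeError immediately instead of quietly setting next_meta to None.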
Calling next() on an iterator is expected to raise a StopIteration once the iterator is exhausted, which you would need to handle, as seen here:
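    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(5))
    loader = DataLoader(dataset)  # default batch_size=1, so 5 batches
    iter_loader = iter(loader)
    for _ in range(6):
        x = next(iter_loader)
        # the 6th call raises StopIteration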
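Putting both points together, a prefetching wrapper could call iter() on the loader once and catch only StopIteration in preload(). The following is a minimal sketch under that assumption, not the original poster's class: the name DataPrefetcher and the _EXHAUSTED sentinel are hypothetical, it accepts any iterable (a DataLoader or, as in the usage lines, a plain list), and it implements the standard __iter__/__next__ protocol so it works in a for loop. It does not do any asynchronous GPU copying; it only illustrates the iterator handling.

```python
# Sentinel, so that a batch which happens to be None is not mistaken for the end.
_EXHAUSTED = object()


class DataPrefetcher:
    """Hypothetical sketch: always holds one batch fetched ahead of the consumer.

    `loader` may be any iterable, such as a torch DataLoader or a plain list.
    """

    def __init__(self, loader):
        self.loader = loader
        self.iterator = iter(loader)  # next() requires an iterator, not an iterable
        self.next_meta = _EXHAUSTED
        self.preload()

    def preload(self):
        try:
            self.next_meta = next(self.iterator)
        except StopIteration:  # catch only exhaustion; real errors still propagate
            self.next_meta = _EXHAUSTED

    def __iter__(self):
        return self

    def __next__(self):
        meta = self.next_meta
        if meta is _EXHAUSTED:
            raise StopIteration  # the underlying iterator is used up
        self.preload()  # fetch the following batch before handing this one out
        return meta

    def __len__(self):
        return len(self.loader)


prefetcher = DataPrefetcher([0, 1, 2, 3, 4])
print(len(prefetcher))   # 5
print(list(prefetcher))  # [0, 1, 2, 3, 4]
```

Because each instance wraps a single iter(loader) pass, you would create a new DataPrefetcher at the start of every epoch. The sketch also drops the cfg and is_train parameters, which the posted preload() does not use.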
Thanks! Maybe I have another question about custom datasets or dataloaders.