I have implemented a very simple IterableDataset in PyTorch:
import numpy as np
import torch
from itertools import cycle
from torch.utils import data

class Maps(data.IterableDataset):
    def __init__(self):
        super().__init__()

    def yield_items(self):
        for i in range(150):
            # simulate loading one large file (a 160 x 256 x 512 array)
            file = np.random.rand(160, 256, 512)
            for j in range(150):
                # each sample is a window of 10 slices; left half = image, right half = label
                img = file[j:j + 10, :, :256]
                label = file[j:j + 10, :, 256:]
                img = torch.tensor(img / np.max(img))
                label = torch.tensor(label / np.max(label))
                yield img.float(), label.float()

    def cycle_data(self):
        return cycle(self.yield_items())

    def __iter__(self):
        return self.cycle_data()

train_dataset = Maps()
data_loader = data.DataLoader(train_dataset)

for x, y in data_loader:
    continue
I am running this code on Google Colab, and it fills up the RAM very quickly until the runtime crashes. Does anyone know what could be causing this?
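For scale, each array generated by np.random.rand(160, 256, 512) is already fairly large on its own; this quick check (shown only to give a sense of the numbers involved) prints its size:

import numpy as np

# one synthetic "file" as generated in yield_items (float64 by default)
file = np.random.rand(160, 256, 512)
print(file.nbytes / 1024 ** 2)  # -> 160.0 MiB per array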
Edit: If I remove the inner loop over j, it works fine, but in my actual use case that loop is unavoidable.
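By "remove the loop over j" I mean a variant roughly like the following (a minimal sketch; the fixed slice 0:10 is just a placeholder for taking a single sample per file):

def yield_items(self):
    for i in range(150):
        file = np.random.rand(160, 256, 512)
        # single sample per file instead of 150 overlapping windows
        img = torch.tensor(file[0:10, :, :256] / np.max(file[0:10, :, :256]))
        label = torch.tensor(file[0:10, :, 256:] / np.max(file[0:10, :, 256:]))
        yield img.float(), label.float()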