I have an LSTM model that I want to use for time-series forecasting. I have a class that extends `pl.LightningDataModule`, and in this class I am creating the data loaders used to train my model.
```python
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader


class BTCPriceDataModule(pl.LightningDataModule):
    def __init__(self, train_sequences, test_sequences, batch_size=8):
        super().__init__()
        self.train_sequences = train_sequences
        self.test_sequences = test_sequences
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = BTCDataset(self.train_sequences)
        self.test_dataset = BTCDataset(self.test_sequences)

    def train_dataloader(self):
        print("coming here")  # debug print: confirms this method is reached
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=2,
        )

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=1)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, num_workers=1)
```
```python
class BTCDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence, label = self.sequences[idx]
        return dict(
            sequence=torch.Tensor(sequence.to_numpy()),
            label=torch.tensor(label).float(),
        )
```
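For context, this is roughly how I wire it up (a sketch; `BTCPricePredictor` and `N_EPOCHS` are placeholders for my actual LSTM LightningModule and epoch count):

```python
model = BTCPricePredictor()  # placeholder for my LSTM LightningModule
data_module = BTCPriceDataModule(train_sequences, test_sequences, batch_size=8)

trainer = pl.Trainer(max_epochs=N_EPOCHS)  # N_EPOCHS: placeholder constant
trainer.fit(model, data_module)
```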
When I run `trainer.fit()`, the sanity check takes forever to complete, so I started debugging my code to check whether there is a problem in the data processing. When I pass my sequential data through the DataLoader, it never manages to finish even a single loop iteration. Example:
```python
data_module = BTCPriceDataModule(train_sequences, test_sequences, batch_size=BATCH_SIZE)
data_module.setup()

for i in data_module.train_dataloader():
    print(i['sequence'].shape)
    print(i['label'].shape)
    break
```
This loop never finishes. I think I gave it enough time, and I also reduced my dataset size, but it still never completes.
`len(data_module.train_dataloader())` is 1921, so with only 1921 batches I am pretty sure this is not heavy enough on my CPU to take more than 5 minutes to run. How can I debug this problem, or is there a known solution?
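For reference, this is the isolation test I am planning to run next (a sketch; the timing and the `num_workers=0` loader are my own additions, not part of the module above):

```python
import time

# Step 1: bypass the DataLoader entirely. If this is fast,
# BTCDataset.__getitem__ itself is not the bottleneck.
dataset = BTCDataset(train_sequences)
start = time.time()
item = dataset[0]
print(item['sequence'].shape, item['label'].shape, time.time() - start)

# Step 2: rebuild the train loader with num_workers=0, so everything
# runs in the main process. If this loop finishes while the
# num_workers=2 version hangs, the problem is in worker startup
# (e.g. notebooks/Windows use "spawn" to start workers), not the data.
loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
for batch in loader:
    print(batch['sequence'].shape, batch['label'].shape)
    break
```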