Lightning Module crashes RAM

Hello,

I am working on Google Colab, and if I try to re-initialise my Lightning module, it crashes the RAM. Whenever I want to make changes to the module, I have to restart the Colab kernel.

This is my module

class PL_MODULE(pl.LightningModule):
    """LightningModule that bundles the ``ANN`` classifier with its data.

    The constructor reads ``x.npy`` and ``y.npy`` from ``data_path`` and splits
    them into train / validation / test arrays that are stored on the instance.
    NOTE: because the arrays are loaded in ``__init__``, every instance of this
    class keeps its own copy of the dataset in host memory.
    """

    def __init__(
        self,
        data_path = "./",
        batch_size = 512,
        learning_rate = 0.0001,
        number_of_cpus: int = 4,
        hidden_size = 15,
        num_layers = 5,
        dropout = 0.1,
    ):
        """
        Args:
            data_path: prefix prepended to ``x.npy`` / ``y.npy`` (plain string
                concatenation, so it must end with a path separator).
            batch_size: batch size used by all three dataloaders.
            learning_rate: learning rate passed to the Adam optimizer.
            number_of_cpus: ``num_workers`` used by all three dataloaders.
            hidden_size: forwarded to ``ANN``.
            num_layers: forwarded to ``ANN``.
            dropout: forwarded to ``ANN``.
        """
        super().__init__()
        self.model = ANN(hidden_size, num_layers, dropout)
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.number_of_cpus = number_of_cpus
        self.loss = nn.CrossEntropyLoss()

        combined_metric = torchmetrics.MetricCollection([torchmetrics.Accuracy(), torchmetrics.Precision(num_classes=11, average='macro'), torchmetrics.F1(num_classes=11, average='macro')])
        # One independent copy per stage so metric state is never shared.
        self.train_metics = combined_metric.clone()
        self.val_metics = combined_metric.clone()
        self.test_metics = combined_metric.clone()
        self.confmat = torchmetrics.ConfusionMatrix(num_classes=11)

        # The ``with`` block closes the file; no explicit close() is needed.
        with open(data_path + 'x.npy', 'rb') as f:
            x = np.load(f)
        with open(data_path + 'y.npy', 'rb') as f:
            # Labels are shifted down by one after flattening.
            y = np.load(f).flatten() - 1

        # 70/30 train+val / test split, then 70/30 train / val split.
        X_train, self.X_test, y_train, self.y_test = train_test_split(x, y, test_size=0.3, random_state=42)
        # The unsplit arrays are not used again; drop them before the second
        # split to keep peak host memory down.
        del x, y
        self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(X_train, y_train, test_size=0.3, random_state=42)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward step in the model
        Args:
            input (torch.Tensor): input tensor for forward step
        Returns:
            torch.Tensor: prediction tensor after forward step
        """
        return self.model(input)

    def _build_loader(self, features, labels) -> torch.utils.data.DataLoader:
        """Wraps a NumPy feature/label pair in a DataLoader
        Args:
            features: NumPy array of inputs
            labels: NumPy array of class indices
        Returns:
            torch.utils.data.DataLoader: DataLoader over the converted tensors
        """
        # Cast directly to the target dtypes (float32 / int64) instead of going
        # through a float64 copy first; ``Variable`` is a deprecated no-op
        # wrapper and is not needed.
        tensor_x = torch.from_numpy(features.astype(np.float32))
        tensor_y = torch.from_numpy(labels.astype(np.int64)).flatten()
        return DataLoader(
            dataset=TensorDataset(tensor_x, tensor_y),
            batch_size=self.batch_size,
            num_workers=self.number_of_cpus,
            pin_memory=True,
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Converts training data to tensor and inserts into DataLoader
        Returns:
            torch.utils.data.DataLoader: DataLoader with training data
        """
        return self._build_loader(self.X_train, self.y_train)

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Converts validation data to tensor and inserts into DataLoader
        Returns:
            torch.utils.data.DataLoader: DataLoader with validation data
        """
        return self._build_loader(self.X_val, self.y_val)

    def test_dataloader(self) -> torch.utils.data.DataLoader:
        """Converts test data to tensor and inserts into DataLoader
        Returns:
            torch.utils.data.DataLoader: DataLoader with test data
        """
        return self._build_loader(self.X_test, self.y_test)

    def configure_optimizers(self) -> torch.optim.Adam:
        """Configures and returns the optimizer
        Returns:
            torch.optim.Adam: Adams optimizer
        """
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, train_batch:torch.FloatTensor, batch_idx:int):
        '''
        Runs one training step on an (x, y) batch and logs the loss and metrics.
        Returns a dictionary with three keys-
        loss - cross-entropy loss for the training step
        logits - raw model outputs for the batch
        y - target labels for the batch
        '''
        x, y = train_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("training_scores",self.train_metics(logits.argmax(dim=1), y), on_step = False, on_epoch = True)
        return {'loss':loss, 'logits': logits, 'y': y}

    def validation_step(self, val_batch:torch.FloatTensor, batch_idx:int):
        '''
        Runs one validation step on an (x, y) batch and logs the loss and metrics.
        Returns a dictionary with three keys-
        loss - cross-entropy loss for the validation step
        logits - raw model outputs for the batch
        y - target labels for the batch
        '''
        x, y = val_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.log("validation_scores",self.val_metics(logits.argmax(dim=1), y), on_step = False, on_epoch = True)
        return {'loss':loss, 'logits': logits, 'y': y}

    def test_step(self, test_batch:torch.FloatTensor, batch_idx:int):
        '''
        Runs one test step on an (x, y) batch, logs the loss and metrics, and
        sends a confusion matrix for the batch to wandb.
        Returns a dictionary with three keys-
        loss - cross-entropy loss for the testing step
        logits - raw model outputs for the batch
        y - target labels for the batch
        '''
        x, y = test_batch
        logits = self.forward(x)
        loss = self.loss(logits, y.long())
        predictions = logits.argmax(dim=1)
        self.log("test_loss", loss, on_epoch = True, on_step = False)
        self.log("testing_scores",self.test_metics(predictions, y), on_step = False, on_epoch = True)
        # Ground-truth labels go in ``y_true`` and model predictions in
        # ``preds``; swapping them transposes the plotted matrix.
        wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(probs = None, y_true = y.cpu().numpy(), preds = predictions.cpu().numpy())})
        return {'loss':loss, 'logits': logits, 'y': y}

and this is my model

class ANN(nn.Module):

    """LSTM classifier: a stacked LSTM followed by two linear layers.

    The LSTM output for every time step is passed through a ReLU, flattened
    to ``seq_len * hidden_size`` features per sample and projected to
    ``num_classes`` logits.
    """

    def __init__(self, hidden_size = 15, num_layers = 5, dropout = 0.1, seq_len = 2000, input_size = 6, num_classes = 11):

        """
            Args:
                hidden_size (int): number of features in the LSTM hidden state
                num_layers (int): number of stacked LSTM layers
                dropout (float): dropout probability passed to the LSTM
                seq_len (int): number of time steps in every input sequence;
                    it fixes the input width of the first linear layer
                input_size (int): number of features per time step
                num_classes (int): number of output logits
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.seq_len = seq_len
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            dropout=self.dropout,
            batch_first=True,
        )
        self.relu = torch.nn.ReLU()
        self.linear_1 = torch.nn.Linear(self.seq_len * self.hidden_size, 1024)
        self.linear_2 = torch.nn.Linear(1024, num_classes)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of sequences.

        Args:
            input (torch.Tensor): batch-first tensor of shape
                (batch, seq_len, input_size)
        Returns:
            torch.Tensor: logits with one row per sample
        """
        output, _ = self.lstm(input)
        output = self.relu(output)
        # Flatten all time steps of each sample into one feature vector.
        output = output.reshape(-1, self.seq_len * self.hidden_size)
        output = self.linear_1(output)
        output = self.linear_2(output)
        return output

What should I do to fix this? I am not able to perform hyperparameter tuning because of this.

@ptrblck any idea about this?

I guess “crashes the RAM” means that you are running out of host memory?
If so, did you try to delete the previous LightningModule as it seems to load the entire dataset in its __init__ method, so it could use more memory compared to using the Dataset and DataLoader outside of this class.

I had tried this

# Build the module; its __init__ reads x.npy and y.npy from data_path
# (default "./") and stores the train/val/test splits on the instance.
model = PL_MODULE(hidden_size = 50, learning_rate = 0.0001, num_layers = 5, batch_size = 128)   
# Remove the name `model`, dropping this reference to the instance.
del model

If this is what you were referring to, it didn’t work. I could not understand your comment on loading the data. Is it inefficient to do it this way? I ask because it doesn’t crash outside of the Lightning module.