CPU RAM fills up in the first epoch, while the GPU memory goes unused!

I’m using Colab Pro with PyTorch.

The GPU allocated to me is a Tesla P100-PCIE-16GB.

However, CPU RAM fills up during the first epoch, and training never even reaches the validation step. My dataset is a set of 128×128 images: 122k in the training set and 21k in the validation set.

Here’s my code:

import torch
import torch.nn.functional as F

# device, learning_rate, num_epochs, val_interval and batch_size are
# defined elsewhere in the notebook

def train(model, train_loader, validation_loader):

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model = model.to(device)

    loss_train_log = []
    loss_val_log = []
    epoch_val_log = []

    print('START TRAINING...')
    for epoch in range(1, num_epochs + 1):
        # Training
        model.train()
        epoch_loss = 0
        for batch_idx, batch_samples in enumerate(train_loader):
            img = batch_samples[0].to(device)      # batch already has shape (N, C, H, W)
            labels = batch_samples[1].to(device)

            optimizer.zero_grad()

            prd = model(img)

            loss = F.mse_loss(prd.view(-1).float(), labels.view(-1).float())

            loss.backward()
            optimizer.step()

            # accumulate a dataset-size-weighted average of the batch losses
            epoch_loss += loss.item() * len(batch_samples[0]) / len(train_loader.sampler)
            
            del loss
            del prd

        loss_train_log.append(epoch_loss)

        del epoch_loss

        print('+ TRAINING \tEpoch: {} \tLoss: {:.6f}'.format(epoch, loss_train_log[-1]))

        # Validation
        if epoch == 1 or epoch % val_interval == 0:
            print("Validation")
            model.eval()
            loss_val = 0


            with torch.no_grad():
                for data_sample in validation_loader:
                    img = data_sample[0].to(device)        # already (N, C, H, W)
                    labels = data_sample[1].to(device)
                    
                    prd = model(img)

                    prd_flat = prd.view(-1)
                    true = labels.view(-1)

                    batch_loss = F.mse_loss(prd_flat.float(), true.float()).item()
                    loss_val += batch_loss * len(data_sample[0]) / len(validation_loader.sampler)
                    del prd
                loss_val_log.append(loss_val)
                epoch_val_log.append(epoch)
                del loss_val


            print('-------------------------------------------------------------')
            print('+ VALIDATE \tEpoch: {} \tLoss: {:.5f}'.format(epoch, loss_val_log[-1]))
            print('-------------------------------------------------------------')


    print('\nFinished TRAINING.')


train_dataset = dataset(X_train, y_train, transform=transform)  # dataset is a class I defined
valid_dataset = dataset(X_valid, y_valid, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)

model = Model()
train(model, train_loader, validation_loader)

Could someone help me with this, please?

Here is my dataset class along with my transform:


import torchvision.transforms as T
from torch.utils.data import Dataset

transform = T.Compose([
    T.ToPILImage(),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

class dataset(Dataset):
    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform

    def __getitem__(self, index):
        if self.transform:
            # the transformed tensor is written back into self.data
            self.data[index] = self.transform(self.data[index])
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

Can you print your device and the data inside the training loop, just to confirm whether the data itself is actually being transferred to the GPU? When you print a tensor that lives on the GPU, the printout ends with its CUDA device id.
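
Something like this, for example (a rough sketch reusing the batch_idx, img, labels and device names from your train() loop):

            # sanity check on the first batch of each epoch
            if batch_idx == 0:
                print(device)                     # e.g. cuda:0
                print(img.device, labels.device)  # both should say cuda:0
                print(img[0, 0, 0, :5])           # printout ends with device='cuda:0' when on GPU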

Thanks for your response.

I figured out the problem: I was storing the data in a Python list instead of a NumPy array.

I documented everything I tried in this post, if someone is interested:
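
Roughly, the change looked like this (a sketch rather than my exact code; load_image and train_paths are placeholders for however the images are actually read):

import numpy as np

# Before (filled CPU RAM): the images lived in a Python list, and __getitem__
# replaced each compact entry with a transformed float32 tensor.
# X_train = [load_image(p) for p in train_paths]

# After: one contiguous uint8 NumPy array of shape (N, 128, 128, 3);
# load_image / train_paths stand in for the actual loading code.
X_train = np.stack([load_image(p) for p in train_paths]).astype(np.uint8)

And in the dataset class, __getitem__ now transforms on the fly instead of mutating the stored data:

    def __getitem__(self, index):
        x = self.data[index]       # uint8 slice of the shared array
        if self.transform:
            x = self.transform(x)  # transform a copy; nothing is written back
        return x, self.label[index]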