CPU memory consumption gradually increases during training

Hi, my CPU memory consumption gradually increases during training. I am using the PyTorch Lightning library to train a temporal model, where each data entry is a 2-tuple: (label, a tensor of 15 stacked images).

The model itself is quite simple: a ViT-inspired architecture. Each of the stacked images goes through a pretrained encoder, a class token is concatenated to the resulting features, and the whole sequence is fed through the Transformer blocks for a classification task. The dataset is quite large (~1.3 TB) and cannot be loaded into memory completely. Each sequence is 15 images.
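
For reference, the shapes involved look roughly like this (a sketch based on the description above; the 224×224 resolution and batch size of 4 are just the example values used further down):

    import torch

    # One dataset entry: an integer label and 15 stacked RGB frames.
    label = 0                                      # class index
    stacked_images = torch.zeros(15, 3, 224, 224)  # (seq_len, channels, H, W)

    # After the DataLoader's default collation with batch_size=4, a batch is:
    #   labels:          (4,)
    #   temporal_images: (4, 15, 3, 224, 224)  -> matches the comment in forward() below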

The Dataset's __getitem__ looks like this:

    @profile
    def __getitem__(self, idx):
        entry = self.csv.iloc[idx, :]
        # column 0 holds the label, column 3 a stringified list of the 15 image paths
        label, temporal_paths = entry.iloc[0], entry.iloc[3].replace(" ", "").replace("'", "").replace("\\\\", "/").lstrip("[").rstrip("]").split(",")

        # pick a new random image size once per batch (train mode only)
        if self.mode == "train":
            if self.batch_count >= self.batch_size:
                self.image_size = random.choice(self.image_sizes)
                self.batch_count = 1
            else:
                self.batch_count += 1

        transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
            ]
        )

        stacked_images = torch.zeros((len(temporal_paths), 3, self.image_size, self.image_size))

        for i in range(len(temporal_paths)):
            with Image.open(temporal_paths[i]) as img:
                img_tensor = transform(img.convert("RGB"))
                stacked_images[i] = img_tensor
        
        # explicitly drop references and force a collection, trying to keep worker memory flat
        entry, temporal_paths, img_tensor = None, None, None
        del entry, temporal_paths, img_tensor
        gc.collect()
        
        return label, stacked_images
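
(To narrow down where inside __getitem__ the growth happens, a small psutil-based RSS logger can be dropped in around the expensive steps; this is only a debugging sketch and assumes psutil is installed.)

    import os
    import psutil

    def rss_mib():
        # Resident set size of the current process, in MiB. Each DataLoader
        # worker is its own process, so call this inside __getitem__.
        return psutil.Process(os.getpid()).memory_info().rss / 2**20

    # inside __getitem__:
    #     before = rss_mib()
    #     ... load and transform the 15 images ...
    #     print(f"worker {os.getpid()}: +{rss_mib() - before:.1f} MiB for idx {idx}")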

The training_step looks like this:

    @profile
    def training_step(self, batch, batch_idx):
        label, temporal_images = batch[0].to(self.device), batch[1].to(self.device)

        output = self.forward(temporal_images)
        loss = nn.CrossEntropyLoss()(output, label)

        # .item() converts to Python floats, so no computation graph is kept around
        self.train_loss += loss.item()
        self.correct += torch.sum(torch.argmax(output, 1) == label).item()
        self.total += label.size(0)

        temporal_images = None
        del temporal_images
        gc.collect()

        return loss
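
(A per-batch CPU memory log like the one plotted further down can be produced with a small Lightning Callback and psutil; this is just a sketch of one way to do it, assuming a recent Lightning version.)

    import os
    import psutil
    import pytorch_lightning as pl

    class CPUMemLogger(pl.Callback):
        """Print the training process' resident memory after every batch."""

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            rss_mib = psutil.Process(os.getpid()).memory_info().rss / 2**20
            print(f"batch {batch_idx}: RSS = {rss_mib:.1f} MiB")

    # trainer = pl.Trainer(callbacks=[CPUMemLogger()], limit_train_batches=30)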

Finally, forward looks like this:

    @profile
    def forward(self, x): # x = 4, 15, 3, 224, 224
        features = torch.zeros(self.args.batch_size, self.args.seq_len, self.encoder.last_linear.in_features, device=self.device)
        
        # encode each frame separately and project to the transformer input dim
        for i in range(x.shape[1]):
            features[:, i, :] = self.proj_layer(self.encoder(x[:, i, :, :, :])[1])

        batch_size = features.size(0)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        features = torch.cat(
            (cls_tokens, features), dim=1
        )  # [b, 15+1, input_dim] or [b, 9+1, input_dim]
        features = features + self.positional_encodings

        features = self.transformer_encoder(features)

        cls_output = features[:, 0, :]

        output = self.mlp(cls_output)

        return output
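
(Aside, not related to the CPU issue: the per-frame loop could also be written as a single encoder call by folding the time dimension into the batch dimension; the sketch below assumes the encoder simply sees a larger batch. Using x.size(0) instead of self.args.batch_size also avoids a shape mismatch on a smaller last batch.)

    def forward(self, x):  # x: (B, T, 3, H, W)
        b, t = x.shape[:2]
        # fold time into the batch dim, encode all frames at once, then unfold back
        features = self.proj_layer(self.encoder(x.flatten(0, 1))[1]).view(b, t, -1)

        cls_tokens = self.cls_token.expand(b, -1, -1)
        features = torch.cat((cls_tokens, features), dim=1) + self.positional_encodings
        features = self.transformer_encoder(features)
        return self.mlp(features[:, 0, :])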

Below is a collage of all the memory profiles, the memory consumption plot, and the installed packages:


(Am new user, one media only :smiling_face_with_tear:)
In forward(), the self.encoder() call is the culprit of the increase. The plot shows the first 30 batches, limited with Lightning's limit_train_batches.

Sometimes I observe a huge drop in memory consumption, e.g. -701 MiB, but overall the memory consumption keeps increasing. I have tried many things to no avail, such as the suggestions in this thread and many other posts I stumbled upon. I am at my wit's end, absolutely devastated. Any help or pointer in the right direction is appreciated. Please let me know if more info is needed.
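
If it helps with debugging, I can also count the live tensors the garbage collector sees once per batch (sketch below; this only catches Python-level tensor objects, not allocator caches):

    import gc
    import torch

    def count_live_tensors():
        """Count tensors currently tracked by the garbage collector."""
        count = 0
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj):
                    count += 1
            except Exception:
                # some objects raise during inspection; skip them
                pass
        return count

    # call once per batch and watch whether the count keeps growing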

@ptrblck shameless tag

I don't see anything obviously wrong in your code. You could start commenting out parts of your code and check when the increase in memory disappears. Usually, appending a tensor that is still attached to the computation graph to a list or another container is the cause of such behavior, but since you are calling .item() on the tensors you accumulate, that should be fine.
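
For reference, the pattern I mean looks like this (schematic):

    # Problematic: the running sum keeps every batch's computation graph alive,
    # so memory grows with the number of iterations.
    self.train_loss += loss

    # Fine: .item() (or .detach()) returns a plain Python float / detached tensor,
    # so no graph is retained.
    self.train_loss += loss.item()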

Thanks! Good to hear nothing is obviously wrong with my code. At least I know I'm heading in the right direction.