Hi, my CPU memory consumption gradually increases during training. I use the PyTorch Lightning library. I am training a temporal model where each data entry is a 2-tuple: (label, a tensor of 15 stacked images).
The model itself is quite simple: a ViT-inspired architecture. Each of the 15 stacked images goes through a pretrained encoder, a class token is concatenated to the resulting sequence of features, and the whole sequence is fed through the Transformer blocks for a classification task. The dataset is quite large (~1.3 TB) and cannot be loaded into memory completely.
The Dataset's __getitem__ (used by the DataLoader) looks like this:
@profile
def __getitem__(self, idx):
    entry = self.csv.iloc[idx, :]
    # Column 0 holds the label, column 3 a stringified list of image paths.
    label = entry.iloc[0]
    temporal_paths = (
        entry.iloc[3]
        .replace(" ", "")
        .replace("'", "")
        .replace("\\\\", "/")
        .lstrip("[")
        .rstrip("]")
        .split(",")
    )
    # Randomly pick a new image size once per batch (train mode only).
    if self.mode == "train":
        if self.batch_count >= self.batch_size:
            self.image_size = random.choice(self.image_sizes)
            self.batch_count = 1
        else:
            self.batch_count += 1
    transform = transforms.Compose(
        [
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
        ]
    )
    # Stack the whole sequence into a single (T, 3, H, W) tensor.
    stacked_images = torch.zeros((len(temporal_paths), 3, self.image_size, self.image_size))
    for i in range(len(temporal_paths)):
        with Image.open(temporal_paths[i]) as img:
            img_tensor = transform(img.convert("RGB"))
            stacked_images[i] = img_tensor
    # Manual cleanup attempt.
    entry, temporal_paths, img_tensor = None, None, None
    del entry, temporal_paths, img_tensor
    gc.collect()
    return label, stacked_images
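Side note on the random per-batch image size: each DataLoader worker gets its own copy of the Dataset, so the batch_count bookkeeping above may not line up with the actual batches when num_workers > 0. Below is a minimal sketch of an alternative I am aware of (a hypothetical resize_collate helper, not what I am currently running) that picks one size per batch inside a custom collate_fn; it assumes __getitem__ returns each sequence as a (T, 3, H, W) tensor at a fixed base size:

import random

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

IMAGE_SIZES = [192, 224, 256]  # placeholder sizes

def resize_collate(batch):
    """Pick one size for the whole batch, then resize and stack every sequence."""
    size = random.choice(IMAGE_SIZES)
    resize = transforms.Resize((size, size), antialias=True)
    labels, sequences = zip(*batch)
    sequences = [resize(seq) for seq in sequences]  # Resize works on (..., H, W) tensors
    return torch.tensor(labels), torch.stack(sequences)

# loader = DataLoader(dataset, batch_size=4, collate_fn=resize_collate, num_workers=4)

This keeps the size consistent within a batch regardless of how many workers are fetching samples.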
The training_step looks like this:
@profile
def training_step(self, batch, batch_idx):
    # Lightning already places the batch on the device, so the explicit .to() calls are effectively no-ops.
    label, temporal_images = batch[0].to(self.device), batch[1].to(self.device)
    output = self.forward(temporal_images)
    loss = nn.CrossEntropyLoss()(output, label)
    # Accumulate epoch-level metrics as plain Python scalars.
    self.train_loss += loss.item()
    self.correct += torch.sum(torch.argmax(output, 1) == label).item()
    self.total += label.size(0)
    # Manual cleanup attempt.
    temporal_images = None
    del temporal_images
    gc.collect()
    return loss
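For reference, here is a sketch of the same bookkeeping done with Lightning's built-in self.log instead of manual counters (nn is torch.nn as above; I am not claiming this changes the memory behaviour, it is just an equivalent formulation):

def training_step(self, batch, batch_idx):
    label, temporal_images = batch
    output = self(temporal_images)
    loss = nn.CrossEntropyLoss()(output, label)
    acc = (output.argmax(dim=1) == label).float().mean()
    # on_epoch=True makes Lightning average these over the epoch
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
    self.log("train_acc", acc, on_step=False, on_epoch=True)
    return loss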
Finally, forward looks like this:
@profile
def forward(self, x):  # x: [batch, seq_len, 3, H, W], e.g. [4, 15, 3, 224, 224]
    # Size the buffer from the actual batch (the last batch can be smaller than args.batch_size).
    features = torch.zeros(
        x.size(0), self.args.seq_len, self.encoder.last_linear.in_features, device=self.device
    )
    # Encode each frame separately and project it into the feature buffer.
    for i in range(x.shape[1]):
        features[:, i, :] = self.proj_layer(self.encoder(x[:, i, :, :, :])[1])
    batch_size = features.size(0)
    # Prepend the class token and add positional encodings.
    cls_tokens = self.cls_token.expand(batch_size, -1, -1)
    features = torch.cat((cls_tokens, features), dim=1)  # [b, 15+1, input_dim] or [b, 9+1, input_dim]
    features = features + self.positional_encodings
    features = self.transformer_encoder(features)
    # Classify from the class-token output.
    cls_output = features[:, 0, :]
    output = self.mlp(cls_output)
    return output
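For comparison, the per-frame loop plus the pre-allocated features buffer could also be written by folding the time dimension into the batch dimension. This is only a sketch: it assumes the encoder returns the same tuple indexed with [1] as above, and it pushes B*T frames through the encoder in one call, so it needs more memory per call:

def forward(self, x):  # x: [B, T, 3, H, W]
    b, t = x.shape[:2]
    frames = x.flatten(0, 1)                          # [B*T, 3, H, W]
    feats = self.proj_layer(self.encoder(frames)[1])  # same [1] indexing as above
    features = feats.reshape(b, t, -1)                # [B, T, feat_dim]
    cls_tokens = self.cls_token.expand(b, -1, -1)
    features = torch.cat((cls_tokens, features), dim=1)
    features = features + self.positional_encodings
    features = self.transformer_encoder(features)
    return self.mlp(features[:, 0, :])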
Below is a collage of all the memory profiles, the memory-consumption plot, and the installed packages:
(I'm a new user, so I can only attach one piece of media.)
In forward(), self.encoder() is the culprit of the increase. The plot shows the first 30 batches, using Lightning's limit_train_batches.
Sometimes I observe a large drop in memory consumption, e.g. -701 MiB, but overall the consumption keeps increasing. I have tried many things to no avail, including the suggestions in this thread and in many other posts I stumbled upon. I am at my wit's end, absolutely devastated. Any help or pointer in the right direction is appreciated. Please let me know if more info is needed.
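In case more per-batch numbers would help, this is the kind of RSS logging I can add (a sketch; it assumes psutil is installed, the pytorch_lightning import name, and a recent Lightning version for the callback hook signature):

import os

import psutil
import pytorch_lightning as pl

class RSSMonitor(pl.Callback):
    """Print the process's resident set size after every training batch."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        rss_mib = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
        print(f"batch {batch_idx}: RSS = {rss_mib:.1f} MiB")

# trainer = pl.Trainer(callbacks=[RSSMonitor()], limit_train_batches=30, ...)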
@ptrblck shameless tag