GPU RAM running out on minimal model example

Hi PT-Community 🤗 First post, here we go.

TLDR: How can I get this model to run on 4 GB of GPU RAM, if that is possible at all?

There seems to be a problem with memory consumption: the model overflows the GPU RAM (4 GB) within a few seconds, even though the text file it trains on is only ~800 lines long.
I’m trying to fine-tune an MLM following James Briggs’ MLM tutorial on YouTube. The code is basically taken from there, with a tweak to the Dataset class. It looks like this:

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForMaskedLM
from tqdm import tqdm

wd_in = os.getcwd() + "/data/in/"
wd_out = os.getcwd() + "/data/out/"
tokenizer = BertTokenizer.from_pretrained('bert-base-german-cased')
model = BertForMaskedLM.from_pretrained('bert-base-german-cased')

# Naive sentence split on '.'
with open(wd_in + "plenar.txt", 'r', encoding='utf8') as fp:
    text = fp.read().split('.')

# Every sentence is truncated or padded to exactly 512 tokens
inputs = tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding='max_length')
inputs['labels'] = inputs.input_ids.detach().clone()

# Pick ~15% of the tokens for masking, skipping [CLS], [SEP] and [PAD].
# Use the tokenizer's own IDs: the hard-coded 101/102/103 are the IDs from the
# English BERT vocabularies and may not match the German one.
rand = torch.rand(inputs.input_ids.shape)
mask_arr = (
    (rand < 0.15)
    & (inputs.input_ids != tokenizer.cls_token_id)
    & (inputs.input_ids != tokenizer.sep_token_id)
    & (inputs.input_ids != tokenizer.pad_token_id)
)

# Apply each row's own mask (the previous loop reused mask_arr[0] for every row)
inputs.input_ids[mask_arr] = tokenizer.mask_token_id


class GerParCorDS(Dataset):
    def __new__(cls, encodings, *args, **kwargs):
        print("Creating class instance")
        instance = super(GerParCorDS, cls).__new__(cls, *args, **kwargs)
        return instance

    def __init__(self, encodings, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encodings = encodings

    def __getitem__(self, idx):
        # The encodings are already tensors; re-wrapping them in torch.tensor()
        # only copies them and raises a UserWarning
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)


dataset = GerParCorDS(inputs)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

device = torch.device('cuda')
model.to(device)
model.train()

optim = torch.optim.AdamW(model.parameters(), lr=1e-5)

epochs = 2

# training
for epoch in range(epochs):
    loop = tqdm(dataloader, leave=True)
    for batch in loop:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optim.step()

        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
```

I also checked for other processes that might occupy GPU memory, but none seem to be competing with PyTorch (`watch nvidia-smi` in a Manjaro Linux shell):

```text
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.73.05    Driver Version: 510.73.05    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| N/A   52C    P8     8W /  N/A |     46MiB /  4096MiB |     17%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A       946      G   /usr/lib/Xorg                      45MiB |
+-----------------------------------------------------------------------------+
```

I don’t know whether this BERT model fits into 4 GB, but you could try running it on e.g. Google Colab (or any other platform providing “free” GPUs), which might provide a GPU with more memory.
There you could then use a single sample and check the GPU memory requirement.
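One way to check that requirement is to run a single training step and read PyTorch's peak-allocation counter. A minimal sketch, assuming the model returns an object with a `.loss` attribute (as `BertForMaskedLM` does when `labels` are passed) and that the batch is a dict of tensors like the ones `GerParCorDS` yields; the function name `peak_step_memory_mib` is hypothetical.

```python
import torch


def peak_step_memory_mib(model, batch, optim, device="cuda"):
    """Run one training step on `batch` and return peak CUDA memory in MiB.

    Returns None on non-CUDA devices, where torch.cuda's counters do not apply.
    The value counts tensors allocated by PyTorch only; nvidia-smi also shows
    the CUDA context and cached blocks, so it will report more.
    """
    if isinstance(device, str):
        device = torch.device(device)
    on_cuda = device.type == "cuda"

    model.to(device)
    model.train()
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)

    optim.zero_grad()
    # Move every tensor of the batch dict (input_ids, attention_mask, labels, ...)
    batch = {key: val.to(device) for key, val in batch.items()}
    loss = model(**batch).loss
    loss.backward()
    optim.step()

    if not on_cuda:
        return None
    return torch.cuda.max_memory_allocated(device) / 2**20
```

For a single sample from the script above, something like `loader = DataLoader(dataset, batch_size=1)` followed by `peak_step_memory_mib(model, next(iter(loader)), optim)` would report the peak for one step.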

Thanks, I will try that and report back. Do you mean single sample as in “just try one sentence”, or is that a feature?

Just one sentence sounds good. The idea is to reduce the memory requirement as much as possible by keeping the input “small”.

TLDR:
It seems that the problem is not the amount of input the model receives, but the batches or the model parameters themselves. Maybe the pretrained model is too heavy?
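Whether the pretrained model alone is too heavy can be estimated before any activations are counted. A rough sketch, assuming float32 training with `torch.optim.AdamW` (which keeps two state tensors, `exp_avg` and `exp_avg_sq`, per parameter) and roughly 110M parameters, the usual BERT-base size; the exact count for `bert-base-german-cased` can be read from `sum(p.numel() for p in model.parameters())`. The helper name is hypothetical.

```python
def training_state_gib(n_params, bytes_per_value=4, optimizer_states=2):
    """Rough lower bound for weights + gradients + optimizer state, in GiB.

    optimizer_states=2 matches AdamW (exp_avg, exp_avg_sq). Activations, the
    CUDA context and allocator caching come on top and are not included.
    """
    copies = 1 + 1 + optimizer_states  # weights + gradients + optimizer state
    return n_params * bytes_per_value * copies / 2**30


# Assumption: ~110M parameters (typical BERT-base size)
print(f"{training_state_gib(110_000_000):.2f} GiB")
```

Under these assumptions, weights, gradients and optimizer state alone come to about 1.64 GiB on a 4 GB card, before the CUDA context and the forward-pass activations are added.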

1. Trying a single sample locally
The first thing I tried was training on a single sentence in my local IDE, i.e. the same setup as before but with less data. I found the line that floods GPU memory, and unsurprisingly it is the call of the model itself. After loading the language model, memory usage hovers at ~1775 MB; once the training loop starts it climbs to ~3700 MB, and the exception is thrown at this line of the transformers library (excerpt):

```python
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
```

In my own code, the exception surfaces in the training loop at

```python
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
```

which is to be expected.
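The size of the tensor produced by that matmul can be worked out by hand, and it matches the largest allocations in the TPU log below (`f32[16,12,512,512]` at 192.00M, `f32[16,512,3072]` at 96.00M). A small sketch, assuming float32 activations and BERT-base dimensions (12 heads of size 64, intermediate size 3072). Because the tokenizer call pads every sentence to `max_length=512`, even a short sentence produces full 512 x 512 attention maps. These figures are per layer; BERT-base stacks 12 layers and keeps activations like these for the backward pass, so the total grows roughly in proportion.

```python
def tensor_mib(shape, bytes_per_value=4):
    """Size in MiB of a dense tensor with the given shape (float32 by default)."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bytes_per_value / 2**20


# Settings from the script: batch_size=16, padding to max_length=512.
# Assumed BERT-base dimensions: 12 attention heads, intermediate size 3072.
batch, heads, seq_len, ffn = 16, 12, 512, 3072

# query_layer [B, H, S, 64] @ key_layer^T [B, H, 64, S] -> scores [B, H, S, S]
print(tensor_mib((batch, heads, seq_len, seq_len)))  # 192.0
# Feed-forward intermediate activation [B, S, 3072]
print(tensor_mib((batch, seq_len, ffn)))  # 96.0

# For comparison: one sample at 512 tokens, or 16 samples at 128 tokens
print(tensor_mib((1, heads, seq_len, seq_len)))  # 12.0
print(tensor_mib((batch, heads, 128, 128)))  # 12.0
```

The attention term scales with batch size and with the square of the sequence length, which is why the padded length matters as much as the number of sentences.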

2. Trying the full text file on a TPU in Google Colab
As @ptrblck suggested, I also tried this on Google Colab, in my case with a TPU. This run uses the original text file of 800 lines. Memory hovers at around 1300 MB, slightly less than when run locally. Training gets through the first 6% of epoch 0 (4 of 62 batches), but then the TPU runs out of memory as well, as the error log says:

```text
RuntimeError: RESOURCE_EXHAUSTED: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) RESOURCE_EXHAUSTED: Ran out of memory in memory space hbm. Used 8.19G of 7.48G hbm. Exceeded hbm capacity by 721.92M.

Total hbm usage >= 8.70G:
    reserved        530.00M
    program           7.69G
    arguments       505.21M
```

This is the full log:

  0%|          | 0/62 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:12: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  if sys.path[0] == '':
Epoch 0:   6%|▋         | 4/62 [05:19<1:17:13, 79.88s/it, loss=17.2]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-e02a8c5faa6b> in <module>()
     14 
     15         loop.set_description(f'Epoch {epoch}')
---> 16         loop.set_postfix(loss=loss.item())
     17 
     18 db = True

RuntimeError: RESOURCE_EXHAUSTED: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) RESOURCE_EXHAUSTED: Ran out of memory in memory space hbm. Used 8.19G of 7.48G hbm. Exceeded hbm capacity by 721.92M.

Total hbm usage >= 8.70G:
    reserved        530.00M 
    program           7.69G 
    arguments       505.21M 

Output size 5.0K; shares 0B with arguments.

Program hbm requirement 7.69G:
    global             4.0K
    HLO temp          6.81G (100.0% utilization: Unpadded (6.24G) Padded (6.24G), 8.3% fragmentation (580.78M))
    overlays        905.37M

  Largest program allocations in hbm:

  1. Size: 905.37M
     XLA label: overlays
     Allocation type: overlays
     ==========================

  2. Size: 192.00M
     Shape: f32[16,12,512,512]{3,2,1,0:T(8,128)}
     Unpadded size: 192.00M
     XLA label: %copy.5019 = f32[16,12,512,512]{3,2,1,0:T(8,128)} copy(f32[16,12,512,512]{2,3,1,0:T(8,128)} %bitcast.9613)
     Allocation type: HLO temp
     ==========================

  3. Size: 192.00M
     Shape: f32[16,12,512,512]{3,2,1,0:T(8,128)}
     Unpadded size: 192.00M
     XLA label: %copy.5028 = f32[16,12,512,512]{3,2,1,0:T(8,128)} copy(f32[16,12,512,512]{2,3,1,0:T(8,128)} %bitcast.9667)
     Allocation type: HLO temp
     ==========================

  4. Size: 192.00M
     Shape: f32[16,12,512,512]{3,2,1,0:T(8,128)}
     Unpadded size: 192.00M
     XLA label: %copy.5219 = f32[16,12,512,512]{3,2,1,0:T(8,128)} copy(f32[16,12,512,512]{2,3,1,0:T(8,128)} %bitcast.2451)
     Allocation type: HLO temp
     ==========================

  5. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.781.remat6 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10823, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2511.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2557, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2712, f32[...
     Allocation type: HLO temp
     ==========================

  6. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.784.remat5 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10832, f32[16,512,768]{2,1,0:T(8,128)} %get-tuple-element.3228, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2552, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2707, ...
     Allocation type: HLO temp
     ==========================

  7. Size: 96.00M
     Shape: bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}
     Unpadded size: 96.00M
     XLA label: %fusion.114.remat5 = bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)} fusion(f32[]{:T(256)S(6)} %divide.2671, f32[16,12,512]{2,1,0:T(8,128)} %reshape.16448, f32[16,12,512]{2,1,0:T(8,128)} %reshape.16444, f32[16,12,512,512]{3,2,1,0:T(8,128)} %copy.5019, f32[16,51...
     Allocation type: HLO temp
     ==========================

  8. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.782.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10826, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2513.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2555, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2710, f32[...
     Allocation type: HLO temp
     ==========================

  9. Size: 96.00M
     Shape: bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}
     Unpadded size: 96.00M
     XLA label: %fusion.451 = (bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}, bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}) fusion(f32[16,12,512,512]{3,2,1,0:T(8,128)} %copy.5219, f32[]{:T(256)S(6)} %divide.2671, f32[16,12,512]{2,1,0:T(8,128)} %fusion.452, f32[]{:T(256)S(6)} %...
     Allocation type: HLO temp
     ==========================

  10. Size: 96.00M
     Shape: bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}
     Unpadded size: 96.00M
     XLA label: %fusion.451 = (bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}, bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}) fusion(f32[16,12,512,512]{3,2,1,0:T(8,128)} %copy.5219, f32[]{:T(256)S(6)} %divide.2671, f32[16,12,512]{2,1,0:T(8,128)} %fusion.452, f32[]{:T(256)S(6)} %...
     Allocation type: HLO temp
     ==========================

  11. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.773.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10799, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2495.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2573, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2728, f32[...
     Allocation type: HLO temp
     ==========================

  12. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.774.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10802, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2497.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2571, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2726, f32[...
     Allocation type: HLO temp
     ==========================

  13. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.775.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10805, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2499.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2569, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2724, f32[...
     Allocation type: HLO temp
     ==========================

  14. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.776.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10808, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2501.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2567, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2722, f32[...
     Allocation type: HLO temp
     ==========================

  15. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.777.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10811, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2503.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2565, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2720, f32[...
     Allocation type: HLO temp
     ==========================

  16. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.778.remat6 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10814, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2505.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2563, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2718, f32[...
     Allocation type: HLO temp
     ==========================

  17. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.779.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10817, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2507.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2561, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2716, f32[...
     Allocation type: HLO temp
     ==========================

  18. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.780.remat6 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10820, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2509.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2559, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2714, f32[...
     Allocation type: HLO temp
     ==========================

  19. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.783.remat6 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10829, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2515.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2553, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2708, f32[...
     Allocation type: HLO temp
     ==========================

  20. Size: 96.00M
     Shape: bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}
     Unpadded size: 96.00M
     XLA label: %fusion.453 = (bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}, bf16[16,12,512,512]{3,2

3. Trying a single sample on Google Colab
When I replace the whole text file with a single sentence, the training runs out much faster.