GPU RAM running out on minimal model example

TL;DR:
It seems the problem is not the amount of input the model receives, but the batches or the model parameters themselves. Maybe the pretrained model is simply too heavy?

1. Trying a single sample locally
The first thing I tried was training on a single sentence in my local IDE, the same setup as before but with less data. I found the line that floods the RAM; obviously enough, it is the call to the model itself. When the language model is loaded, RAM hovers at ~1775 MB; later, when the training loop is called, it sits at ~3700 MB and the exception is thrown at this line of code in the transformers library (an excerpt):

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

In my own code, the exception is thrown in the training loop at

outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

which is to be expected.
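For anyone who wants to reproduce the measurement, a memory check like the following is enough to see the spike around the forward call (a minimal sketch; the helper name log_mem is just for illustration, not my exact code):

    import os
    import psutil

    def log_mem(tag: str) -> None:
        # Resident set size (RSS) of the current process, in MB
        rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
        print(f"{tag}: {rss_mb:.0f} MB")

    # Called around the suspect lines, roughly:
    #   log_mem("model loaded")    -> ~1775 MB
    #   log_mem("before forward")  -> ~3700 MB, then the attention matmul above raises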

2. Trying the original amount with TPU in Google Colab
As @ptrblck suggested, I also tried this with a TPU in Google Colab. This version uses the original text file of 800 lines. RAM hovers at around 1300 MB, slightly less than when run locally. Training gets through about 6% of the first epoch, but then the TPU runs out of memory as well. As the error log says:

RuntimeError: RESOURCE_EXHAUSTED: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) RESOURCE_EXHAUSTED: Ran out of memory in memory space hbm. Used 8.19G of 7.48G hbm. Exceeded hbm capacity by 721.92M.

Total hbm usage >= 8.70G:
    reserved        530.00M 
    program           7.69G 
    arguments       505.21M 
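(For context, the Colab TPU run follows the usual torch_xla pattern, roughly like the sketch below; my actual loop differs in details, so take names like loader and optimizer as placeholders rather than my verbatim code.)

    import torch_xla.core.xla_model as xm

    device = xm.xla_device()              # a single TPU core
    model.to(device)

    for batch in loader:                  # placeholder DataLoader
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        outputs.loss.backward()
        xm.optimizer_step(optimizer)      # reduces gradients and steps the optimizer on the XLA device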

This is the full log:

  0%|          | 0/62 [00:00<?, ?it/s]/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:12: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  if sys.path[0] == '':
Epoch 0:   6%|▋         | 4/62 [05:19<1:17:13, 79.88s/it, loss=17.2]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-e02a8c5faa6b> in <module>()
     14 
     15         loop.set_description(f'Epoch {epoch}')
---> 16         loop.set_postfix(loss=loss.item())
     17 
     18 db = True

RuntimeError: RESOURCE_EXHAUSTED: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) RESOURCE_EXHAUSTED: Ran out of memory in memory space hbm. Used 8.19G of 7.48G hbm. Exceeded hbm capacity by 721.92M.

Total hbm usage >= 8.70G:
    reserved        530.00M 
    program           7.69G 
    arguments       505.21M 

Output size 5.0K; shares 0B with arguments.

Program hbm requirement 7.69G:
    global             4.0K
    HLO temp          6.81G (100.0% utilization: Unpadded (6.24G) Padded (6.24G), 8.3% fragmentation (580.78M))
    overlays        905.37M

  Largest program allocations in hbm:

  1. Size: 905.37M
     XLA label: overlays
     Allocation type: overlays
     ==========================

  2. Size: 192.00M
     Shape: f32[16,12,512,512]{3,2,1,0:T(8,128)}
     Unpadded size: 192.00M
     XLA label: %copy.5019 = f32[16,12,512,512]{3,2,1,0:T(8,128)} copy(f32[16,12,512,512]{2,3,1,0:T(8,128)} %bitcast.9613)
     Allocation type: HLO temp
     ==========================

  3. Size: 192.00M
     Shape: f32[16,12,512,512]{3,2,1,0:T(8,128)}
     Unpadded size: 192.00M
     XLA label: %copy.5028 = f32[16,12,512,512]{3,2,1,0:T(8,128)} copy(f32[16,12,512,512]{2,3,1,0:T(8,128)} %bitcast.9667)
     Allocation type: HLO temp
     ==========================

  4. Size: 192.00M
     Shape: f32[16,12,512,512]{3,2,1,0:T(8,128)}
     Unpadded size: 192.00M
     XLA label: %copy.5219 = f32[16,12,512,512]{3,2,1,0:T(8,128)} copy(f32[16,12,512,512]{2,3,1,0:T(8,128)} %bitcast.2451)
     Allocation type: HLO temp
     ==========================

  5. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.781.remat6 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10823, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2511.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2557, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2712, f32[...
     Allocation type: HLO temp
     ==========================

  6. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.784.remat5 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10832, f32[16,512,768]{2,1,0:T(8,128)} %get-tuple-element.3228, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2552, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2707, ...
     Allocation type: HLO temp
     ==========================

  7. Size: 96.00M
     Shape: bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}
     Unpadded size: 96.00M
     XLA label: %fusion.114.remat5 = bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)} fusion(f32[]{:T(256)S(6)} %divide.2671, f32[16,12,512]{2,1,0:T(8,128)} %reshape.16448, f32[16,12,512]{2,1,0:T(8,128)} %reshape.16444, f32[16,12,512,512]{3,2,1,0:T(8,128)} %copy.5019, f32[16,51...
     Allocation type: HLO temp
     ==========================

  8. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.782.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10826, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2513.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2555, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2710, f32[...
     Allocation type: HLO temp
     ==========================

  9. Size: 96.00M
     Shape: bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}
     Unpadded size: 96.00M
     XLA label: %fusion.451 = (bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}, bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}) fusion(f32[16,12,512,512]{3,2,1,0:T(8,128)} %copy.5219, f32[]{:T(256)S(6)} %divide.2671, f32[16,12,512]{2,1,0:T(8,128)} %fusion.452, f32[]{:T(256)S(6)} %...
     Allocation type: HLO temp
     ==========================

  10. Size: 96.00M
     Shape: bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}
     Unpadded size: 96.00M
     XLA label: %fusion.451 = (bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}, bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}) fusion(f32[16,12,512,512]{3,2,1,0:T(8,128)} %copy.5219, f32[]{:T(256)S(6)} %divide.2671, f32[16,12,512]{2,1,0:T(8,128)} %fusion.452, f32[]{:T(256)S(6)} %...
     Allocation type: HLO temp
     ==========================

  11. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.773.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10799, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2495.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2573, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2728, f32[...
     Allocation type: HLO temp
     ==========================

  12. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.774.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10802, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2497.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2571, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2726, f32[...
     Allocation type: HLO temp
     ==========================

  13. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.775.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10805, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2499.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2569, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2724, f32[...
     Allocation type: HLO temp
     ==========================

  14. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.776.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10808, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2501.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2567, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2722, f32[...
     Allocation type: HLO temp
     ==========================

  15. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.777.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10811, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2503.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2565, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2720, f32[...
     Allocation type: HLO temp
     ==========================

  16. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.778.remat6 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10814, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2505.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2563, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2718, f32[...
     Allocation type: HLO temp
     ==========================

  17. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.779.remat7 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10817, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2507.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2561, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2716, f32[...
     Allocation type: HLO temp
     ==========================

  18. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.780.remat6 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10820, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2509.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2559, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2714, f32[...
     Allocation type: HLO temp
     ==========================

  19. Size: 96.00M
     Shape: f32[16,512,3072]{2,1,0:T(8,128)}
     Unpadded size: 96.00M
     XLA label: %fusion.783.remat6 = f32[16,512,3072]{2,1,0:T(8,128)} fusion(f32[3072]{0:T(1024)} %fusion.10829, f32[16,512,768]{2,1,0:T(8,128)} %fusion.2515.remat6, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2553, f32[16,512]{1,0:T(8,128)} %get-tuple-element.2708, f32[...
     Allocation type: HLO temp
     ==========================

  20. Size: 96.00M
     Shape: bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}
     Unpadded size: 96.00M
     XLA label: %fusion.453 = (bf16[16,12,512,512]{3,2,1,0:T(8,128)(2,1)}, bf16[16,12,512,512]{3,2
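What stands out in the allocation list: the recurring 192.00M blocks are f32[16,12,512,512] tensors, i.e. per-layer attention score matrices for batch size 16, 12 heads and a 512-token sequence (dimensions that look like a BERT-base-sized model). A quick sanity check in plain Python reproduces the sizes from the log and shows why batch size and, quadratically, sequence length dominate here:

    batch, heads, seq, ffn = 16, 12, 512, 3072

    f32_scores  = batch * heads * seq * seq * 4   # f32[16,12,512,512]
    bf16_scores = batch * heads * seq * seq * 2   # bf16[16,12,512,512]
    f32_ffn     = batch * seq * ffn * 4           # f32[16,512,3072] (FFN intermediate)

    print(f32_scores  / 2**20)   # 192.0 MiB -> the "Size: 192.00M" entries
    print(bf16_scores / 2**20)   #  96.0 MiB -> the bf16 "Size: 96.00M" entries
    print(f32_ffn     / 2**20)   #  96.0 MiB -> the f32[16,512,3072] entries

So cutting the batch size (or the max sequence length, which enters the attention scores squared) shrinks exactly the allocations that push past the ~7.5 GB of HBM, which fits the TL;DR above: it is the per-batch activations, not the amount of raw text.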

3. Trying a single sample on Google Colab
When the whole text file is replaced with a single sentence, the training runs out of memory much faster.
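My working assumption for why the input size barely matters: if the preprocessing pads every example to a fixed max_length of 512 (as most masked-LM examples do, and as the 512-long shapes in the log above suggest), a single short sentence still yields full-length tensors, so the per-step activations are exactly as large as before. A minimal check, with the checkpoint name and max_length as assumptions rather than my exact settings:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    enc = tokenizer(
        "A single short sentence.",
        padding="max_length",    # pad up to max_length no matter how short the input is
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    print(enc["input_ids"].shape)   # torch.Size([1, 512]) -> same sequence length as the full dataset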