Out of RAM for computing loss with expanded tensor

Hello,
I ran into this error today, and I can't tell whether this is intended behavior or a bug.

Let me show you an example script:

import torch 
import torch.cuda
import torch.nn.functional as F

START_SHAPE = 500, 1000, 1000
END_SHAPE = 2, *START_SHAPE

def contiguous_test():
    print("memory used (MB):", torch.cuda.max_memory_allocated() / 2**20)
    A = torch.rand(START_SHAPE).cuda()
    print("memory used (MB):", torch.cuda.max_memory_allocated() / 2**20)
    B = torch.rand(START_SHAPE).cuda()
    print("memory used (MB):", torch.cuda.max_memory_allocated() / 2**20)
    # tensors are contiguous
    print(A.is_contiguous())
    print(B.is_contiguous())
    L = F.mse_loss(A, B, reduction='none')
    print("memory used (MB):", torch.cuda.max_memory_allocated() / 2**20)
    # let's allocate another array of END_SHAPE
    T = torch.rand(END_SHAPE).cuda()
    print("memory used (MB):", torch.cuda.max_memory_allocated() / 2**20)
    # now in RAM we have: 3 tensors with START_SHAPE and one tensor with END_SHAPE
    # we have enough RAM to do it!
    k = A.shape + B.shape + L.shape + T.shape # <- just to be sure that variables are not deleted

def non_contiguous_test():
    print("memory used (MB):", torch.cuda.max_memory_allocated() / 2**20)
    A = torch.rand(START_SHAPE).cuda()
    A = A.expand(END_SHAPE) # <- this doesn't use extra RAM
    print("memory used (MB):", torch.cuda.max_memory_allocated() / 2**20)
    B = torch.rand(START_SHAPE).cuda()
    B = B.expand(END_SHAPE) # <- this doesn't use extra RAM
    print("memory used (MB):", torch.cuda.max_memory_allocated() / 2**20)
    # now tensors aren't contiguous anymore
    print(A.is_contiguous())
    print(B.is_contiguous())
    L = F.mse_loss(A, B, reduction='none')
    print("memory used (MB):", torch.cuda.max_memory_allocated() / 2**20)
    # now in RAM we have: 2 tensors with START_SHAPE and one tensor with END_SHAPE
    # but we don't have enough RAM for it. Why?
    k = A.shape + B.shape + L.shape # <- just to be sure that variables are not deleted

if __name__ == "__main__":
    import sys
    # just add the function name and we'll execute it
    eval(sys.argv[1])()

You can run python test_ram.py contiguous_test or python test_ram.py non_contiguous_test to call the two different functions.

In contiguous_test we instantiate two tensors, compute the element-wise loss between them (which has the same shape as the inputs) and then instantiate a bigger tensor T. In non_contiguous_test we instantiate two tensors, expand them without using extra RAM, and then compute the loss, which should have the same shape as T.

To sum up, here are the tensors that we are allocating:

contiguous_test    non_contiguous_test    shape
A                  A                      (500, 1000, 1000)
B                  B                      (500, 1000, 1000)
mse_loss(A, B)     -                      (500, 1000, 1000)
tensor T           -                      (2, 500, 1000, 1000)
-                  mse_loss(A, B)         (2, 500, 1000, 1000)

Thus, in contiguous_test we are using more RAM than in non_contiguous_test. However, contiguous_test works, while non_contiguous_test does not. Why? It seems that the loss is calling contiguous() on both tensors, but why is that needed?
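A quick back-of-the-envelope check (assuming float32 tensors, i.e. 4 bytes per element, and assuming that mse_loss really materializes both expanded inputs, which is only my hypothesis) makes the numbers plausible:

MB = 2**20
small = 500 * 1000 * 1000 * 4 / MB    # one START_SHAPE tensor: ~1907 MB
big = 2 * 500 * 1000 * 1000 * 4 / MB  # one END_SHAPE tensor:   ~3815 MB

# contiguous_test: A, B and the loss (all START_SHAPE) plus T (END_SHAPE)
print(3 * small + big)                # ~9536 MB -> fits on a 12 GB card

# non_contiguous_test, if both expanded inputs get materialized on top of
# A, B and the END_SHAPE output of the loss
print(2 * small + 3 * big)            # ~15259 MB -> does not fit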

$ python test_ram.py contiguous_test
memory used (MB): 0.0
memory used (MB): 1908.0
memory used (MB): 3816.0
True
True
memory used (MB): 5724.0
memory used (MB): 9538.697265625
$ nvidia-smi
Tue Jan  7 11:52:40 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.64       Driver Version: 430.64       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  TITAN V             Off  | 00000000:01:00.0 Off |                  N/A |
| 29%   43C    P8    29W / 250W |      0MiB / 12066MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
$ python test_ram.py non_contiguous_test
memory used (MB): 0.0
memory used (MB): 1908.0
memory used (MB): 3816.0
False
False
Traceback (most recent call last):
  File "test_ram.py", line 46, in <module>
    eval(sys.argv[1])()
  File "test_ram.py", line 37, in non_contiguous_test
    L = F.mse_loss(A, B, reduction='none')
  File "/home/federico/.pyenv/versions/3.6.9/lib/python3.6/site-packages/torch/nn/functional.py", line 2204, in mse_loss
    ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
RuntimeError: CUDA out of memory. Tried to allocate 3.73 GiB (GPU 0; 11.78 GiB total capacity; 7.45 GiB already allocated; 3.27 GiB free; 1.30 MiB cached)




Edit and partial solution

After digging into it a bit, I discovered that the implementation of MSELoss in PyTorch is rather inefficient for this case. In essence, the code used to compute it is ret = (input - target) ** 2.

Now, input - target cannot be performed in place because both operands are views. Thus, I think PyTorch ends up cloning both of them. A better approach would be output = input.clone(); output -= target, which clones only one of the two tensors.
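To see why, here is a small illustration with scaled-down shapes (they are tiny on purpose, so it runs anywhere, even on CPU): expand only adds a zero-stride dimension, but any arithmetic between two expanded views has to materialize the full broadcast result.

import torch

a = torch.rand(5, 10, 10)
b = torch.rand(5, 10, 10)

A = a.expand(2, 5, 10, 10)  # a view: no new memory, the new leading dimension has stride 0
B = b.expand(2, 5, 10, 10)
print(A.stride())           # (0, 100, 10, 1)

out = A - B                 # the result is a real (2, 5, 10, 10) tensor
print(out.stride())         # (500, 100, 10, 1): fully materialized
print(out.is_contiguous())  # True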

The same story also applies to other losses.

The following implementation works for me:

def myloss(A, B, reduction='none'): 
    out = A.clone()
    out -= B
    return out.pow_(2)

The following instead runs out of RAM during the difference:

def myloss(A, B, reduction='none'): 
    out = A - B
    return out.pow_(2)

The following runs out of RAM during the power, because it is not computed in place:

def myloss(A, B, reduction='none'): 
    out = A.clone()
    out -= B
    out = out**2
    return out           
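To compare the variants empirically, a sketch along these lines measures the peak GPU memory of each one (loss_inplace, loss_naive and the small shapes below are made up just for this example; reset_max_memory_allocated() is the older API name, recent releases also offer reset_peak_memory_stats()):

import torch

def loss_inplace(A, B):
    out = A.clone()          # one full-size allocation
    out -= B                 # in place, no extra memory
    return out.pow_(2)       # in place, no extra memory

def loss_naive(A, B):
    return (A - B) ** 2      # two full-size allocations: the difference and the power

def peak_mb(fn, *args):
    # peak GPU memory allocated (including already-resident tensors) while running fn, in MB
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

# small shapes so this runs on any GPU; scale them up to reproduce the numbers above
A = torch.rand(50, 100, 100, device='cuda').expand(2, 50, 100, 100)
B = torch.rand(50, 100, 100, device='cuda').expand(2, 50, 100, 100)
print("clone + in-place ops:", peak_mb(loss_inplace, A, B))
print("naive out-of-place:  ", peak_mb(loss_naive, A, B))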

Hi,

I think your analysis is correct.
The internal implementation is most likely analogous to:

def myloss(A, B, reduction='none'): 
    out = A - B
    return out.pow_(2)

Since we are not allowed to change either A or B in place, and we want it to work in all cases, this leads to the simplest implementation.

Note that the difference with the one that works for you is minimal (it only allocates memory slightly differently).
In general, I would advise not to run so close to the maximum memory: any change to the rest of your code might change the order of the allocations (and the fragmentation), and this will start to fail again.
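If you want to check how close you are to the limit (and whether fragmentation is playing a role), you can inspect the allocator state, e.g. (memory_reserved() is called memory_cached() in older releases, and memory_summary() may not exist in very old versions):

import torch

allocated = torch.cuda.memory_allocated() / 2**20  # memory actually held by tensors
reserved = torch.cuda.memory_reserved() / 2**20    # memory held by the caching allocator
print("allocated: %.1f MB, reserved: %.1f MB" % (allocated, reserved))

# a per-pool breakdown, useful for spotting fragmentation
print(torch.cuda.memory_summary())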
