PyTorch reports out of memory on print

Funny issue I ran into. If you create a dtype=torch.float16 tensor on the GPU that takes up almost all of the available memory, you will actually run out of memory when trying to print it. This isn't really a problem per se, but I was curious whether anyone knew what is going on in the __str__ method that causes this, and whether it can be fixed.

Could you post an executable code snippet to reproduce this issue as well as the output of python -m torch.utils.collect_env, please?

Here is the output of the command:

Collecting environment information...
PyTorch version: 1.9.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.19.3
Libc version: N/A

Python version: 3.9 (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: True
CUDA runtime version: 11.2.67
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 461.92
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.9.0+cu102
[pip3] torchaudio==0.9.0
[pip3] torchvision==0.10.0+cu102
[conda] Could not collect

And here is a snippet that should do the job. If it does not, you might have to fiddle with the starting value of val; of course, it depends on how much memory is available on your GPU. You will have found the issue when the stack trace points at the print statement rather than at the allocation.

import torch

device = torch.device("cuda")
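val = 17
while True:
    x = torch.ones((int(val),), dtype=torch.float16, device=device)
    print(x)  # the OOM is expected to surface here once x is large enough
    val *= 1.1  # grow the tensor by 10% per iteration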

I think it's expected that you would run out of memory if you are using an infinite loop and allocating a larger tensor in each iteration.
val is increased by a factor of 1.1 per iteration, so the tensor sizes for the first iterations are 17, 18, 20, 22, 24, 27, 30, 33, 36, 40, and so on.
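The growth can be checked in plain Python without allocating anything. This sketch mirrors the loop in the snippet (start value 17, factor 1.1, truncation with int()); the helper name loop_sizes is hypothetical.

```python
def loop_sizes(start=17, factor=1.1, count=10):
    """Hypothetical helper: element counts passed to torch.ones in the repro loop."""
    sizes = []
    val = start
    for _ in range(count):
        sizes.append(int(val))  # int() truncates, just like int(val) in the snippet
        val *= factor
    return sizes


print(loop_sizes())  # [17, 18, 20, 22, 24, 27, 30, 33, 36, 40]
```

Because the size grows geometrically, the loop reaches the GPU's capacity after a few hundred iterations at most, whatever the card.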


Since the loop never terminates, you eventually run out of memory.

It looks like torch.half/torch.float16 is a special case (distinct from float32): the data is promoted (!) to float32 for printing, because printing is CPU-side logic that happens in float32. The float32 copy needs twice the memory of the float16 tensor (4 bytes per element instead of 2), so PyTorch tries to grab that much additional GPU memory, which is why you get the OOM on the call to print. It might be interesting to consider the tradeoff between speed (casting on the GPU is probably faster) and memory footprint.
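To put numbers on the doubling, here is a back-of-the-envelope sketch in plain Python (no PyTorch needed). It assumes the whole tensor is promoted in one shot, as described above, and that the original float16 tensor stays allocated while the copy exists; the struct module is used only to look up the element sizes, and the helper name print_cast_footprint is hypothetical.

```python
import struct

HALF_BYTES = struct.calcsize("e")    # IEEE 754 half precision (torch.float16): 2 bytes
SINGLE_BYTES = struct.calcsize("f")  # IEEE 754 single precision (torch.float32): 4 bytes


def print_cast_footprint(numel):
    """Hypothetical helper: memory involved when a float16 tensor with `numel`
    elements is fully promoted to float32 before printing.

    Returns (tensor_bytes, copy_bytes, peak_bytes), assuming the original tensor
    is still alive while the float32 copy exists.
    """
    tensor_bytes = numel * HALF_BYTES
    copy_bytes = numel * SINGLE_BYTES  # the promoted copy is twice as large
    return tensor_bytes, copy_bytes, tensor_bytes + copy_bytes


# A float16 tensor of 4 * 2**30 elements occupies 8 GiB; promoting it asks for
# another 16 GiB, far more than the 11 GB on an RTX 2080 Ti.
GIB = 2**30
t, c, p = print_cast_footprint(4 * GIB)
print(t // GIB, c // GIB, p // GIB)  # 8 16 24
```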


For the sake of completeness: as discussed offline, it seems to come from this line of code, and it might be worth checking whether a printoptions flag could be added to lower the memory footprint at the cost of a slower print operation.


Maybe I am not seeing the full picture, but printing a large tensor usually shows only a few of its values. A solution could be to gather just the values we intend to print and cast only those, which would use very little memory while still casting on the GPU.
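The gather-then-cast idea can be sketched without PyTorch by treating a bytes buffer of packed float16 values as a stand-in for a tensor's storage. This assumes a 1-D tensor and borrows the default print options of torch.set_printoptions (threshold 1000, edgeitems 3); the helper name summarize_then_cast is hypothetical, and real multi-dimensional summarization keeps edge items along every dimension.

```python
import struct

THRESHOLD = 1000  # default summarization threshold in torch.set_printoptions
EDGEITEMS = 3     # default number of leading/trailing items shown when summarizing
HALF_BYTES = struct.calcsize("<e")  # 2 bytes per float16 element


def summarize_then_cast(half_buffer):
    """Hypothetical helper: decode only the float16 values a summarized print shows.

    `half_buffer` holds packed little-endian float16 values and stands in for a
    1-D tensor's storage. Only the selected elements are converted (here to
    Python floats), so the extra memory scales with the number of printed items,
    not with the size of the tensor.
    """
    numel = len(half_buffer) // HALF_BYTES
    if numel > THRESHOLD:
        # Gather the first and last EDGEITEMS positions before any conversion.
        indices = list(range(EDGEITEMS)) + list(range(numel - EDGEITEMS, numel))
    else:
        indices = list(range(numel))
    return [struct.unpack_from("<e", half_buffer, HALF_BYTES * i)[0] for i in indices]


values = [float(i) for i in range(2000)]  # integers up to 2048 are exact in float16
buffer = struct.pack(f"<{len(values)}e", *values)
print(summarize_then_cast(buffer))  # [0.0, 1.0, 2.0, 1997.0, 1998.0, 1999.0]
```

In PyTorch terms the same idea would mean indexing the edge elements on the GPU first and promoting only that small slice to float32.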

Just saw this, sorry. This code is specifically meant to induce an out-of-memory error on print. You can replicate it with a single allocation and print, but I don't know a priori how much memory you have on your GPU.