Funny issue I ran into. If you make a dtype=torch.float16 tensor on the GPU that takes up almost all of the available memory, you will actually run out of memory when trying to print it. This isn't really an issue per se, but I was curious whether anyone knows what is going on in the `__str__` method that causes this, and whether it could be fixed.
Could you post an executable code snippet to reproduce this issue as well as the output of `python -m torch.utils.collect_env`, please?
Here is the output of the command:
```
Collecting environment information...
PyTorch version: 1.9.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 10 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.19.3
Libc version: N/A
Python version: 3.9 (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: True
CUDA runtime version: 11.2.67
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 461.92
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.9.0+cu102
[pip3] torchaudio==0.9.0
[pip3] torchvision==0.10.0+cu102
[conda] Could not collect
```
And here is a snippet which should do the job. If it does not, you may have to fiddle a bit to find the magic value for `val`; it depends, of course, on how much memory is available on your GPU. You will know you have found the issue when the stack trace contains the actual print statement.
```python
import torch

device = torch.device("cuda")
val = 17
while 1:
    x = torch.ones((int(val),), dtype=torch.float16, device=device)
    print(x)
    val *= 1.1
```
I think it's expected that you would run out of memory if you are using an infinite loop and allocating a larger tensor in each iteration. `val` is increased by a factor of 1.1, so the tensor sizes for the first few iterations would be:
```
torch.Size([17])
torch.Size([18])
torch.Size([20])
torch.Size([22])
torch.Size([24])
torch.Size([27])
torch.Size([30])
torch.Size([33])
torch.Size([36])
torch.Size([40])
```
Since the loop never terminates, you will eventually run out of memory.
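The size progression can be checked without a GPU by reproducing just the loop arithmetic (a small standalone sketch, no tensors allocated):

```python
val = 17
sizes = []
for _ in range(10):
    sizes.append(int(val))  # tensor length allocated in this iteration
    val *= 1.1              # geometric growth, as in the repro snippet

print(sizes)  # [17, 18, 20, 22, 24, 27, 30, 33, 36, 40]
```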
It looks like `torch.float16` is a special case (distinct from `float32`), because the data is promoted (!) to `float32` for printing (the printing logic runs on the CPU and works in `float32`). This naturally causes PyTorch to try to grab twice the current tensor's worth of GPU memory for the `float32` copy, which is why you get the OOM on the call to `print`.
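The doubling is easy to see from the element sizes alone; a minimal sketch using CPU tensors (the arithmetic is the same on the GPU):

```python
import torch

x = torch.ones(1_000_000, dtype=torch.float16)
print(x.element_size() * x.numel())  # 2000000 bytes for the float16 data

# The cast that printing performs materializes a full float32 copy,
# i.e. an extra allocation twice the size of the original tensor.
y = x.float()
print(y.element_size() * y.numel())  # 4000000 bytes for the float32 copy
```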
For the sake of completeness: as discussed offline, it seems to come from this line of code, and it might be worth checking if a `printoptions` flag could be added to lower the memory footprint and pay with a slower print instead.
Maybe I am not seeing the full picture, but printing a tensor usually reports very few of its values. A solution could be to simply gather the values we intend to print and cast only those, using very little memory while still casting on the GPU.
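As a rough illustration of that idea (the function name and the `edgeitems` default here are made up for the sketch, not PyTorch API), one could cast only the edge elements that a summarized printout of a 1-D tensor would display:

```python
import torch

def summarized_cast(x, edgeitems=3):
    """Cast to float32 only the elements a summarized print would show."""
    if x.numel() <= 2 * edgeitems:
        return x.float()
    # Gather the leading/trailing edge items first, then cast the tiny
    # result; the full-size float32 copy is never materialized.
    return torch.cat([x[:edgeitems], x[-edgeitems:]]).float()

x = torch.arange(10, dtype=torch.float16)
print(summarized_cast(x))  # tensor([0., 1., 2., 7., 8., 9.])
```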
Just saw this, sorry. This code is specifically meant to induce an out-of-memory error on print. You can replicate it with a single allocation and a print, but I don't know a priori how much memory you have on your GPU.
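On versions that ship `torch.cuda.mem_get_info` (added after the 1.9 release used in this thread), the single-allocation variant can size itself automatically. The fraction 0.6 is an assumption chosen so that the float16 tensor fits in free memory, but the extra 2x allocation for its float32 print copy does not (0.6 + 1.2 > 1):

```python
import torch

def target_numel(free_bytes, elem_size=2, frac=0.6):
    # Number of float16 elements whose tensor fits in free memory (frac < 1)
    # but whose float32 print copy (2x extra) no longer does (3 * frac > 1).
    return int(free_bytes * frac) // elem_size

if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info()  # assumes PyTorch >= 1.10
    x = torch.ones(target_numel(free), dtype=torch.float16, device="cuda")
    print(x)  # expected to OOM inside the print machinery, not the alloc
```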