How do I create a Torch Tensor without any wasted storage space/baggage?

From testing experience, the first tensor pushed to the GPU takes up roughly 700-800 MiB of GPU VRAM. You can then allocate more tensors on the GPU without any change in VRAM usage until you exceed the space pre-allocated by that first tensor.

Example:

import torch

x = torch.tensor(1).cuda()
# ~770 MiB allocated
z = torch.zeros((1000,)).cuda()
# VRAM not affected
d = torch.zeros((1000000000,)).cuda()
# plus ~1000 MiB (a more accurate allocation: number of elements multiplied by the dtype's itemsize)

How do I create the first tensor without 700+ MiB of VRAM baggage? This would be especially useful for inference/deployment. Most deep learning applications have static input allocations; for example, a batch of images of size (10, 3, 200, 200) in float32 only requires 4.8 MB of VRAM. It does not make sense in deployment to waste so much space on your model inputs.
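
For reference, a minimal sketch of that arithmetic (plain PyTorch on the CPU, nothing CUDA-specific):

import torch

# example input batch from above: 10 images, 3 channels, 200x200, float32
batch = torch.zeros((10, 3, 200, 200), dtype=torch.float32)
nbytes = batch.numel() * batch.element_size()   # 1,200,000 elements * 4 bytes
print(nbytes / 1e6)                             # -> 4.8 (MB)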

Is there a specific way to initialise a PyTorch tensor similar to how we do it in PyCUDA/CUDA with allocate and copy? It would be awesome if there were a way for developers to specify the nbytes of their tensor, as in CUDA.
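
For context, the PyCUDA pattern I mean is roughly this (just a sketch; only the requested nbytes are allocated on the device, on top of the driver context itself):

import numpy as np
import pycuda.autoinit          # creates the CUDA driver context
import pycuda.driver as cuda

host = np.zeros((10, 3, 200, 200), dtype=np.float32)
dev = cuda.mem_alloc(host.nbytes)    # allocates exactly host.nbytes on the GPU
cuda.memcpy_htod(dev, host)          # explicit host-to-device copy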

Hopefully there is already a way in PyTorch; otherwise, only writing the CUDA code ourselves can ensure correct memory allocation for our inputs. A PyCUDA GpuArray does not transfer well to a PyTorch Tensor (the other way around works, though), which leads to the need to code every single function in PyCUDA from scratch.


No, that’s not the case. PyTorch will create the CUDA context in the very first CUDA operation, which can use ~600-1000MB of GPU memory depending on the CUDA version as well as the device used.
PyTorch itself will allocate the needed memory and will use an internal caching mechanism. You can read more about it here.
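
As a rough sketch of how to inspect what the caching allocator holds versus what your tensors actually use (the exact numbers will vary):

import torch

x = torch.randn(1024, 1024, device="cuda")
print(torch.cuda.memory_allocated() / 2**20)   # MiB used by live tensors
print(torch.cuda.memory_reserved() / 2**20)    # MiB held by the caching allocator
del x
torch.cuda.empty_cache()                       # returns cached blocks to the driver
print(torch.cuda.memory_reserved() / 2**20)
# note: neither number includes the CUDA context itself, which only shows up in nvidia-smi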

That’s not possible, since the CUDA context has to be created, which contains e.g. the device code.


Hi @ptrblck ,

Thank you for your prompt reply. I did some further digging, thanks to your information.

I have a few follow-up questions, if you do not mind:

  1. Is there a way in torch internals to lower the VRAM necessary for the CUDA context?
  2. (related to 1) Is there a way to remove unnecessary internal caching if I am 100% sure that the only torch tensor that will be allocated is the input tensor? As expected, I always see a lot of memory reserved that is not in use.
  3. If PyCUDA has already created a context, must we still create a new context for PyTorch in the same running script?

I have PyCUDA + TensorRT code in the same script as my PyTorch pre/post-processing code, so the context has already been created:

import pycuda.driver as cuda

cuda.init()
ctx = cuda.Device(0).make_context()

It seems that creating the context as shown above only takes up ~200 MiB, while PyTorch initialisation takes up to 770 MiB. Tested on my GTX 1080 with CUDA 11.1.
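
For reference, this is roughly how I measured those numbers (a free-memory delta via PyCUDA's mem_get_info; it is device-wide, so other processes on the GPU are included):

import pycuda.driver as cuda
import torch

cuda.init()
ctx = cuda.Device(0).make_context()
free0, total = cuda.mem_get_info()
print((total - free0) / 2**20, "MiB used after PyCUDA context creation")

x = torch.tensor(1).cuda()           # first CUDA op: PyTorch creates its own context
free1, _ = cuda.mem_get_info()
print((free0 - free1) / 2**20, "MiB added by PyTorch initialisation")

ctx.pop()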

If there is no way to reduce the memory cached/reserved by PyTorch tensors, then it seems that PyTorch can only be used for quick development to train and test models. It seems like the workflow has to be:

PyTorch Development → Get model weights → Convert to TensorRT → Convert all pre and post-processing to CUDA implementation (super time consuming)

  1. You could rebuild PyTorch and remove libraries shipping with device code, such as cuDNN. While this would yield a performance hit, no cuDNN kernels will be loaded and stored in the CUDA context.
  2. Yes, you can use PYTORCH_NO_CUDA_MEMORY_CACHING=1 to disable the cache (see the sketch after this list).
  3. PyTorch needs to load its own kernels as well as device code from other libs (cuDNN, cublas, NCCL, etc.), which might not be the case for PyCUDA. It might be possible to share some driver-related code in the context, but I don’t know how much memory savings would be expected and haven’t experimented with it.
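
Regarding 2., a minimal sketch of disabling the cache; the variable has to be set before the first CUDA allocation, so exporting it in the shell before launching the script is the safest option:

import os
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"   # must be set before the first CUDA allocation

import torch

x = torch.zeros((1000,), device="cuda")
del x
# with caching disabled, freed blocks go straight back to the driver
# instead of being kept in PyTorch's cache for reuse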

Hi @ptrblck, once again, thank you for the prompt reply.

From your answers, I guess the only way is to figure out which libraries the PyTorch functions I need actually use (e.g. for concat, stacking, NMS) and remove the rest during the build (e.g. cuDNN is not needed since the DNN part is managed by TensorRT).

Appreciate the help.


Potential future discussion for AI developers that stumble upon this thread.

It seems that if we want to deploy using the PyTorch functions (NMS, stacking, etc.) that made our lives so much easier, we have to deal with the huge memory allocated for its context.

This problem must be faced by many who want to deploy multiple modules, each leveraging PyTorch for ease of GPU programming. The context memory stacks up very quickly and fills the entire VRAM of the GPU, especially since deployment GPUs usually only have 8-10 GiB of memory.

My personal workflow is to use PyTorch for preprocessing and postprocessing, as manipulating PyTorch tensors is very easy. The pointer of a PyTorch-processed tensor (pycudatorch.py · GitHub) can then be passed into TensorRT (the optimised model), and the output from TensorRT remains a PyTorch tensor, allowing very easy postprocessing with PyTorch's readily available functions. You can also use CUDA kernels you have written yourself (via PyCUDA) when you need more control over the implementation of individual threads than PyTorch provides. The obvious weakness of this workflow is the PyTorch context memory allocation; I currently cannot run more than 2 models on my GPU because of this.
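
For anyone curious, the core of that interop is just handing TensorRT the raw device pointers of contiguous PyTorch tensors. A rough sketch (the engine path, tensor shapes and binding order are placeholders, and the exact API differs across TensorRT versions):

import tensorrt as trt
import torch

logger = trt.Logger(trt.Logger.WARNING)
with open("model.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# preprocess in PyTorch; tensors must be contiguous and already on the GPU
inp = torch.zeros((10, 3, 200, 200), dtype=torch.float32, device="cuda").contiguous()
out = torch.empty((10, 1000), dtype=torch.float32, device="cuda")   # placeholder output shape

stream = torch.cuda.current_stream()
context.execute_async_v2(
    bindings=[inp.data_ptr(), out.data_ptr()],   # raw device pointers, in binding order
    stream_handle=stream.cuda_stream,
)
stream.synchronize()
# `out` is already a torch.Tensor, so postprocessing stays in PyTorch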

I guess the only way for developers who need to deploy multiple models/modules and are constrained by GPU memory (usually the case) is to not use PyTorch at all for deployment: write the functions you need in CUDA yourself and use TensorRT as usual.

You can try recompiling PyTorch with the CUDA kernels compiled at runtime on demand; I think a config line like TORCH_CUDA_ARCH_LIST="6.0+PTX" would do that. The first run will be very slow, then the kernels will be loaded from the disk cache. I'm not certain that this will have a big effect, though.

I’ve also heard something about a “static runtime” somewhere on GitHub, but I’m not sure if anything materialized from that.

While building PyTorch only for your compute capability would speed up the build process, the CUDA context should already load only the device's kernels, so I would not expect to see a reduction in memory usage.

@tlim Besides the kernels loaded from libraries, you would also need to check the native PyTorch kernels, which are often templated for various dtypes etc. and would also use memory on the device. @eqy profiled the CUDA context size a while ago and could probably add more information.


If you aren’t using any linear algebra routines that would be provided by MAGMA, compiling with USE_MAGMA=0 should also give you around 100 MiB or so of memory savings.

Of course, once you are willing to pay recompilation costs and are willing to specialize your build for certain use cases, the world is your oyster when it comes to deleting unused GPU kernels to save memory usage…
The hard limit for this would probably be about 300 MiB, i.e. the bare CUDA context size.

And somehow you can get small (<100 MB) contexts in general, but big monoliths like PyTorch have huge contexts from the start.

That’s interesting, as I haven’t seen such a small context in a while. Could you point me to an example, if possible?
E.g. using the matrixMul CUDA sample on an A100, I see a memory usage of ~400 MB for a matmul of [1, 1] x [1, 1].
Building for sm_80 or for all architectures also doesn’t change the usage.

Yeah, I quoted a number for ancient cards (or an old SM version, not sure); 200 MB is a more reasonable baseline, as measured above. The point was that PyTorch’s context is bigger than that baseline before any kernels are even needed.

@ptrblck, regarding your mention of different GPUs and CUDA versions producing different context memory allocations: I tested a GTX 1080 and an RTX 2080 (on CUDA 11.1, PyTorch 1.8.1), and it seems that the GTX 1080 allocates 771 MiB while the RTX 2080 allocates 1177 MiB.

I am wondering whether there is some sort of general guide mapping CUDA, PyTorch, driver and GPU versions to the memory needed for context allocation. That would be much easier to follow than trying out individual custom builds.

At the end of the day, it seems that writing the processing code in CUDA is the only way to go for deployment due to the VRAM limitation.

You could try to work out some theoretical requirements, which would depend on the actual hardware setup, but given that each PyTorch, cuDNN, cublas, CUDA, and NCCL version might add/remove kernels for a specific compute capability, the easiest way is to profile the CUDA context directly.
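
E.g. a small sketch of that kind of profiling via NVML (assuming the pynvml package is installed; the measured delta also includes the caching allocator's first block, not just the context):

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

def used_mib():
    # device-wide used memory in MiB (includes every process on the GPU)
    return pynvml.nvmlDeviceGetMemoryInfo(handle).used / 2**20

before = used_mib()
import torch
x = torch.tensor(1).cuda()       # first CUDA op: context creation + kernel loading
torch.cuda.synchronize()
print(f"context + first allocation: ~{used_mib() - before:.0f} MiB")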

@ptrblck I guess that’s the only way to go about it. Thanks for the advice.