Cross-post from Stack Overflow
I have written a simple Python script that uses the HuggingFace `transformers` library along with `torch` to run Llama3.1-8B-instruct purely for inference, after feeding in some long-ish bits of text (about 10k-20k tokens). It runs fine on my laptop, which has a GPU with 12GB of dedicated RAM but can also access up to 28GB in total (I guess shared from the main system RAM?).
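For context, the core of the script is along these lines (a simplified sketch: the model path, prompt wording, and generation settings here are stand-ins rather than my exact code):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to("cuda")

documents = ["..."]  # long texts, roughly 10k-20k tokens each


def respond(document: str) -> str:
    # Placeholder prompt; the real task wording differs
    messages = [{"role": "user", "content": f"Please respond to the following:\n\n{document}"}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    output = model.generate(input_ids, max_new_tokens=512)
    # Strip the prompt tokens from the returned sequence
    return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)


for doc in documents:
    print(respond(doc))
```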
However, when I run the code on a “Standard NC4as T4 v3” Windows Virtual Machine, with a single Tesla T4 GPU with 16GB RAM, it very quickly throws this error: `CUDA out of memory. Tried to allocate XXX GiB. GPU …`
According to my calculations, this code should run fine given the available RAM. Nevertheless, I’ve tried to make the script more memory efficient:
- Changing the attention mechanism, by setting `attn_implementation` when instantiating the model, first to `"sdpa"`. I then tried to adopt flash attention, but found it impossible to install the `flash-attn` package on Windows.
- Using `xformers` to run `enable_xformers_memory_efficient_attention` (I think this is also about the attention mechanism), but I couldn't install/run it.
- Using `torch.inference_mode()`.
- Setting flags like `low_cpu_mem_usage=True` during model instantiation.
- Explicitly setting `torch_dtype` to `torch.float16` or `torch.bfloat16` during model instantiation.
- Using `BitsAndBytesConfig` to trigger 8-bit and 4-bit quantization (a sketch of how I've been combining these options follows this list).
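Roughly, the most memory-conscious loading code I've ended up with looks like this (simplified; the specific 4-bit settings, NF4 and double quantization, are my guesses at sensible values rather than anything I'm sure about):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

# 4-bit quantization via bitsandbytes
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # the T4 has no native bfloat16 support
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    attn_implementation="sdpa",   # flash-attn wouldn't install on Windows
    torch_dtype=torch.float16,    # dtype for the non-quantized modules
    low_cpu_mem_usage=True,
)


@torch.inference_mode()
def generate_reply(input_ids: torch.Tensor) -> torch.Tensor:
    # input_ids: the tokenized long prompt, as in the earlier sketch
    return model.generate(input_ids.to(model.device), max_new_tokens=512)
```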
The last step was the only one to have any effect: with quantization the model manages to generate one or two responses, but it still fails to complete the loop, and in some cases still fails on the very first inference. If I've understood the documentation properly, 4-bit quantization should drastically reduce memory requirements and make it almost trivial to do a forward pass with a model of this size on this GPU.
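My back-of-envelope for the 4-bit case, assuming I've read the Llama 3.1 8B config correctly (32 layers, 8 key/value heads, head dimension 128), goes roughly like this:

```python
# Very rough memory estimate for 4-bit inference with ~20k tokens of context.
n_params = 8.0e9
weights_gb = n_params * 0.5 / 1e9                # ~4 GB of 4-bit weights
kv_bytes_per_token = 2 * 32 * 8 * 128 * 2        # K+V x layers x kv_heads x head_dim x fp16
kv_cache_gb = kv_bytes_per_token * 20_000 / 1e9  # ~2.6 GB at 20k tokens
print(weights_gb + kv_cache_gb)                  # ~6.6 GB before overheads
```

Even allowing a few extra GB for the embedding and `lm_head` layers (which, as far as I understand, bitsandbytes leaves in 16-bit) and for activation spikes during the forward pass, that seems like it should fit comfortably on a 16GB card.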
So I'm suspicious about trying to optimise the code any further; the next step would apparently be to manually map the model layers to the GPU and CPU respectively by defining a `device_map`, along the lines of the sketch below.
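By that I mean something like the following, which I haven't actually tried yet (the memory budgets here are guesses, and the fully manual version would instead be a dict mapping individual layer names to devices):

```python
import torch
from transformers import AutoModelForCausalLM

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",                        # let accelerate place the layers
    max_memory={0: "14GiB", "cpu": "24GiB"},  # per-device budgets (guesses)
)
```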
So my questions are:
- Is this likely to be truly a memory issue, or is it a red herring?
- If it’s memory, what could I do beyond what I’ve tried?
- If it's not, are there obvious things to check? (I suspect something to do with system configuration or package installations…)