Model occupying 2.5x its size when loading and running

I am trying to load and run the meta-llama/Meta-Llama-3-8B on a 16 GB Nvidia RTX 4060 Ti.

I am using the following BitsAndBytesConfig to load the model in 4-bit and use torch.bfloat16 for compute.

import torch
from transformers import AutoModel, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    # load_in_8bit=True,
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModel.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    quantization_config=bnb_config,
    device_map='auto'
)

Just loading the model, the nvidia-smi output shows about 4.5 GB of GPU memory being used.

Then, when I run the following code, the GPU memory immediately shoots up to 15.5 GB.

inp_sample = next(iter(train_loader))
logits = model(inp_sample['input_ids'])

I think I am missing something, because running a single batch of shape [2, 512] through Meta-Llama-3-8B shouldn't increase the memory usage by this much.

Added a couple of screenshots for reference. Would love any pointers on this one.
Just Loading the model

I want to try fine-tuning the model afterwards, but that seems impossible if an 8B model quantized to 4-bit already runs a 16 GB GPU out of memory.

Adding the second screenshot here because the help forum doesn't allow new users to have more than one embedded item per post.

After running one inference step through the model

Besides the model parameters and buffers, the model execution will also use memory to store the intermediate forward activations needed to compute the gradients. You could use e.g. this util to estimate the memory usage.
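
As a rough check without any extra tooling, PyTorch's built-in CUDA memory stats can show how much of the usage comes from the forward pass. A minimal sketch, assuming `model` is the 4-bit model loaded as in your post:

import torch

# dummy batch with the same shape as in the post ([2, 512])
input_ids = torch.randint(0, model.config.vocab_size, (2, 512), device="cuda")

torch.cuda.reset_peak_memory_stats()
print(f"before forward: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

out = model(input_ids)  # gradients are enabled, so intermediate activations are kept

print(f"after forward:  {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"peak forward:   {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

The difference between the "before" and "after" values gives a rough idea of how much memory is held by the forward activations rather than by the weights themselves.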

If you don't want to compute gradients at all, you should wrap the forward pass in a torch.no_grad() context, if that's not already done.
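
For example, a minimal sketch reusing the names from your post:

inp_sample = next(iter(train_loader))

with torch.no_grad():  # no gradients needed, so intermediate activations are not stored
    out = model(inp_sample['input_ids'])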
