And here’s the solution:
import os

# Restrict the process to one GPU. This must be set before CUDA is
# initialized, i.e. before importing torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import gc

import torch
from transformers import LlamaForCausalLM
# load_in_8bit quantization requires the bitsandbytes package; device_map={'': 0}
# puts the whole model on the single visible GPU (physical device 6, seen as 0).
model = LlamaForCausalLM.from_pretrained(
    pretrained_model_name_or_path='decapoda-research/llama-7b-hf',
    load_in_8bit=True,
    device_map={'': 0},
)
# Drop the only reference to the model, force a collection so Python actually
# frees the tensors, then ask PyTorch's caching allocator to return its
# cached blocks to the driver.
del model
gc.collect()
torch.cuda.empty_cache()
print('breakpoint here - is memory freed?')
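To check at that breakpoint whether the memory actually went back, you can compare PyTorch's own counters before and after the teardown instead of eyeballing nvidia-smi. Here is a minimal, self-contained sketch of the same pattern; it uses the standard torch.cuda measurement APIs, with a ~1 GiB dummy tensor standing in for the model so it runs on its own:

import gc

import torch

def report(tag: str) -> None:
    # memory_allocated: bytes currently held by live tensors;
    # memory_reserved: bytes the caching allocator keeps from the driver.
    alloc = torch.cuda.memory_allocated(0) / 2**20
    reserved = torch.cuda.memory_reserved(0) / 2**20
    print(f'{tag}: allocated={alloc:.0f} MiB, reserved={reserved:.0f} MiB')

report('baseline')
x = torch.empty(1024, 1024, 256, device='cuda')  # ~1 GiB stand-in for the model
report('after alloc')
del x
gc.collect()
torch.cuda.empty_cache()  # return cached blocks to the driver
report('after teardown')

If reserved drops back to the baseline here but nvidia-smi still shows usage for the original script, the leftover is most likely the CUDA context itself (a few hundred MB that is only released when the process exits) or a stray reference to the model kept alive inside transformers/bitsandbytes.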