LLaMA: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Hi. I built a sentiment classification model on top of LLaMA (just adding an output layer that outputs 3 logits). Since LLaMA is a huge model, I quantized it using the code snippet below.

    import torch
    from transformers import BitsAndBytesConfig

    def create_quantization_config():
        """
        Quantization config generator.
        Creates a configuration for 4-bit quantization.
        """
        bnb_config = BitsAndBytesConfig(
            # The bnb_4bit_* options below only take effect with load_in_4bit
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="fp4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        return bnb_config
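
The snippet above does not show how the config is passed to the model. One setup often suggested for multi-GPU (DDP) training with a bitsandbytes-quantized model is to load a full copy per process directly onto that process's GPU through the `device_map` argument of `from_pretrained`, rather than relying on a later `.to()` call, which some `transformers` versions reject for bitsandbytes-quantized models. A minimal sketch, assuming the script is launched with `torchrun` (which sets the `LOCAL_RANK` environment variable); `per_process_device_map` is a hypothetical helper name:

```python
import os


def per_process_device_map() -> dict:
    """Hypothetical helper: build a device_map that places the whole model
    on the GPU owned by the current process.

    Assumes a torchrun launch, where LOCAL_RANK is set for every worker;
    falls back to GPU 0 for a plain single-process run.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    # In a device_map, the empty-string key "" refers to the whole model.
    return {"": local_rank}
```

It would then be passed alongside the quantization config, e.g. `from_pretrained(model_name, quantization_config=create_quantization_config(), device_map=per_process_device_map())` on whichever class the LLaMA backbone is loaded with (`model_name` is a placeholder). If the model is instead loaded with `device_map="auto"`, its layers may be split across both GPUs, which is one way to end up with the embedding weights on `cuda:0` while a process sends its inputs to `cuda:1`.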

I then trained the quantized model with my trainer class; the per-batch step looks like this:

    from torch import Tensor

    def _run_batch(self, src: list, tgt: Tensor) -> float:
        # Run one training step on a single batch.
        # src holds the two model inputs (e.g. token ids and attention mask).
        self.optimizer.zero_grad()
        src[0] = src[0].to(self.gpu_id)
        src[1] = src[1].to(self.gpu_id)
        tgt = tgt.to(self.gpu_id)
        self.model = self.model.to(self.gpu_id)
        out = self.model(src[0], src[1]).to(self.gpu_id)
        loss = self.criterion(out, tgt)
        loss.backward()
        self.optimizer.step()

        self.train_acc.update(out, tgt)
        self.val_acc.update(out, tgt)
        return loss.item()

However, when I try to fine-tune LLaMA on 2 GPUs, I get this error:

    in embedding
        return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

This is strange because when I did exactly the same thing with BERT and GPT-2, there was no problem. I even trained those models on 4 GPUs (single node) without any issues. I suspect it has something to do with the quantization, but I am not sure. Does anybody have an idea?
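
To narrow this down, it can help to print where each parameter actually ended up in each process, for example from inside `_run_batch` on the second GPU. A small sketch with a hypothetical helper name; it only reads the `.device` attribute of each tensor, so it works on any iterable of `(name, tensor)` pairs such as PyTorch's `model.named_parameters()`:

```python
from collections import defaultdict


def devices_by_parameter(named_params) -> dict:
    """Hypothetical helper: group parameter names by the device they live on.

    named_params: any iterable of (name, tensor) pairs, e.g.
    model.named_parameters(). Only tensor.device is used.
    """
    groups = defaultdict(list)
    for name, tensor in named_params:
        # str(torch.device("cuda", 0)) == "cuda:0", so keys read naturally
        groups[str(tensor.device)].append(name)
    return dict(groups)
```

Calling `print(self.gpu_id, devices_by_parameter(self.model.named_parameters()))` in each process would show whether, in the process with `gpu_id == 1`, the embedding weights are still on `cuda:0` even though the inputs were moved to `cuda:1`.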

Thanks in advance.