How can I successfully fine-tune a pruned LLM?

Hello,

I am currently studying how to successfully fine-tune several LLMs after pruning them, so that their performance on a set of tasks is restored as much as possible. In particular, I am pruning the following models:

  • Llama-2-7b-hf
  • Llama-2-13b-hf
  • DeepSeek-R1-Distill-Llama-8B
  • DeepSeek-R1-Distill-Qwen-14B
  • DeepSeek-R1-Distill-Qwen-32B

I am evaluating them on the tasks arc_challenge, arc_easy, hellaswag, lambada_openai, openbookqa, piqa and winogrande, using lm-eval (EleutherAI/lm-evaluation-harness on GitHub), which is very convenient. I prune the MLP layers, that is, mlp.gate_proj, mlp.up_proj, mlp.down_proj, or a combination of them, by setting square blocks (512x512 or 256x256) of the weight matrix to zero. I need the models to regain the lost performance, but I am having a hard time getting there. Here is the training script that I am currently using:

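    import transformers
    from peft import LoraConfig, get_peft_model
    from torch.optim import AdamW

    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        # Llama and Qwen2 name the attention output projection 'o_proj', not 'output_proj'
        target_modules=['q_proj','k_proj','v_proj','o_proj'],
    )
    model = get_peft_model(model, lora_config)
    # This optimizer is handed to the Trainer below, so its lr (2e-4) is used
    # rather than learning_rate in TrainingArguments
    optimizer = AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)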

    training_args = transformers.TrainingArguments(
        output_dir="./results",
        logging_strategy="epoch",
        eval_strategy="steps",
        eval_steps=500,
        learning_rate=1e-4,
        auto_find_batch_size=True,
        gradient_accumulation_steps=1,
        num_train_epochs=2,
        eval_on_start=True,
        bf16=True,
        log_on_each_node=False,
    )

    trainer = transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"],
        eval_dataset=lm_datasets["validation"],
        processing_class=tokenizer,
        data_collator=data_collator,
        optimizers=(optimizer, None)
    )
    trainer.can_return_loss = True

    trainer.train()

    # Merge LoRA structure
    model = model.merge_and_unload()

I have tried training on 80,000 samples of the C4 dataset (sometimes 160,000, but the results barely improve). I use 4 H100 64GB GPUs. I preprocess the data as follows:

    from datasets import DatasetDict, load_dataset

    raw_datasets = load_dataset(
        datasets_path,
        split=[f'train[0:{train_number_samples}]', f'validation[:{validation_number_samples}]'],
        cache_dir="./cache_training",
    )
    raw_datasets = DatasetDict({
        "train": raw_datasets[0],
        "validation": raw_datasets[1],
    })

    column_names = list(raw_datasets["train"].features)
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples):
        return tokenizer(examples[text_column_name])

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=column_names,
    )

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size_texts.
    def group_texts(examples):
        block_size_texts = 1024
        # Concatenate all texts.
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder. We could pad instead if the model supported it;
        # customize this part to your needs.
        total_length = (total_length // block_size_texts) * block_size_texts
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size_texts] for i in range(0, total_length, block_size_texts)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
    )

Does everything look fine here? The validation loss decreases considerably, so the training process is clearly doing something. I have tried removing the blocks with the lowest L2 norm, hoping that this would harm the model less, but I see no difference compared to removing blocks at random. So I am considering (1) training without LoRA, (2) keeping the initial and final layers intact, or (3) extending the target_modules list, but I am not sure whether any of these will be effective.
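For reference, selecting the lowest-norm blocks could look roughly like the sketch below. It assumes a 2D PyTorch weight whose dimensions are divisible by the block size; `block_mask_by_norm`, `block` and `sparsity` are hypothetical names for illustration, not part of the script above.

```python
import torch

def block_mask_by_norm(weight: torch.Tensor, block: int, sparsity: float) -> torch.Tensor:
    """Return a boolean keep-mask (True = keep) that drops the lowest-norm blocks."""
    rows, cols = weight.shape
    assert rows % block == 0 and cols % block == 0, "shape must be divisible by block"
    # View the matrix as a grid of (block x block) tiles.
    tiles = weight.reshape(rows // block, block, cols // block, block)
    # L2 (Frobenius) norm of each tile -> shape (rows/block, cols/block).
    norms = tiles.float().pow(2).sum(dim=(1, 3)).sqrt()
    n_prune = int(sparsity * norms.numel())
    keep = torch.ones(norms.numel(), dtype=torch.bool)
    if n_prune > 0:
        # Indices of the n_prune tiles with the smallest norm.
        prune_idx = torch.topk(norms.flatten(), n_prune, largest=False).indices
        keep[prune_idx] = False
    keep = keep.reshape(norms.shape)
    # Expand the per-tile decision back to one entry per weight.
    return keep.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)

# Toy example: an 8x8 matrix, 4x4 blocks, 25% of the blocks removed.
torch.manual_seed(0)
w = torch.randn(8, 8)
mask = block_mask_by_norm(w, block=4, sparsity=0.25)
w_pruned = w * mask
```

If the pruned matrices themselves are trained (option 1), a mask like this would have to be re-applied after every optimizer step, otherwise the zeroed blocks fill up again.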

Can anyone give me some better ideas?

Thank you!

Could you share a little more detail about how exactly you prune the weights? Zeroing 256x256 (or even 512x512) blocks seems pretty harsh to me. I don’t have a ton of experience with this, and there are plenty of people here who can give a more sophisticated answer, so take the following with a grain of salt:
I would try some kind of train-time pruning, where you fine-tune your model with an added regularization term (probably something as simple as the L1/L2 norm works) to encourage a sparse weight matrix. If you do this right, I think there is a chance you can use Nvidia’s 2:4 sparsity. I did this once and it really made a difference for me, although the setup was less complicated.
But maybe you tried this already?
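For context, 2:4 sparsity means that in every group of four consecutive weights at most two are non-zero. A minimal sketch of a magnitude-based 2:4 mask is shown below; `two_four_mask` is a hypothetical helper, and keeping the two largest-magnitude entries per group is just one simple selection rule, not necessarily what was used here.

```python
import torch

def two_four_mask(weight: torch.Tensor) -> torch.Tensor:
    """Boolean mask keeping the 2 largest-magnitude entries in each group of 4 columns."""
    rows, cols = weight.shape
    assert cols % 4 == 0, "number of columns must be divisible by 4"
    # Group consecutive entries of each row in fours.
    groups = weight.abs().reshape(rows, cols // 4, 4)
    # Positions of the two largest magnitudes inside every group.
    top2 = groups.topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, top2, True)
    return mask.reshape(rows, cols)

w = torch.tensor([[1.0, -3.0, 0.5, 2.0, 0.1, 0.2, -0.3, 0.4]])
mask = two_four_mask(w)
w_sparse = w * mask
```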

Hi, thanks for your quick answer! I know it looks harsh, yes. I do it this way because I can speed up the model with some custom kernels. To be honest, most of the model’s metrics are destroyed by zeroing 25% of the total weights (even after training), whereas there are methods that remove more than 50% and still get reasonable results.

I am afraid the Nvidia 2:4 pattern wouldn’t help here, since I really need the sparse-block structure, and 2:4 implies removing single elements, not blocks. However, maybe I can try your approach. Can you give me more hints on how to apply your method? Let’s say I have a list of masks (one per weight matrix) and I want to prune the elements specified by each mask. Would I need to add a penalty to the loss function to discourage non-zeros in positions where the mask is False, or something like this?

I basically added the L1 or L2 norm of the matrix that I wanted to prune to the loss function while fine-tuning. Ideally this drives some of the weights close to zero. You can then define a threshold and set the weights that are “close enough” to zero to exactly zero.
But I am afraid that won’t work for you in this case, because this method gives you no guarantee about which weights will be pruned. It is hard enough to get 2:4 sparsity to work that way, and I don’t think there is much chance of pruning large blocks like that. But maybe if you reduce the block size? Hope that makes sense.
Where in the network are the layers that you want to prune? If I remember correctly, the first and last linear layers are way more sensitive to pruning than the “middle” ones. Maybe you can try to get rid of some blocks in the middle and keep the outer ones?
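The regularize-then-threshold recipe described above can be sketched on a toy layer as follows. Everything here (the layer size, the toy target, `l1_strength`, `threshold`) is an assumption chosen for illustration; it only shows the mechanics, not settings that are known to work for an LLM.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(16, 16, bias=False)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

# Toy regression data: the target depends only on the first 4 input features.
x = torch.randn(64, 16)
y = x[:, :4].sum(dim=1, keepdim=True).expand(-1, 16)

l1_strength = 1e-2  # hypothetical value, would need tuning
for _ in range(200):
    opt.zero_grad()
    task_loss = torch.nn.functional.mse_loss(layer(x), y)
    # The L1 term pushes weights that the task does not need towards zero.
    loss = task_loss + l1_strength * layer.weight.abs().sum()
    loss.backward()
    opt.step()

threshold = 1e-2  # hypothetical cut-off for "close enough to zero"
with torch.no_grad():
    small = layer.weight.abs() < threshold
    layer.weight.masked_fill_(small, 0.0)

sparsity = small.float().mean().item()  # fraction of weights set to exactly zero
```

Note that nothing in this loop controls where the zeros end up, which is the limitation mentioned above for block-structured pruning.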

I think you have given me an idea. It’s true that your method does not guarantee which elements are pruned, but, given a mask, I could modify the loss in the following way:

    # mask is 1 at the positions that should end up pruned, 0 elsewhere
    l1 = torch.sum(torch.abs(layer.weight * mask))
    loss += 1e-4 * l1

(I guess L1 could be replaced by L2, and 1e-4 could be tuned.)
This makes sense to me because only non-zero weights in the wrong positions contribute to the loss. But in any case, thanks for the ideas! I will try this, and also reducing the size of the blocks. I currently prune all the layers; maybe I will change that too and only prune the inner ones.

Thanks
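Extended to several layers, the masked penalty could be sketched like this. It assumes the pruned matrices are themselves trainable (with LoRA adapters on other modules, the penalty would have no effect on them), and that each mask is 1 where a weight should go to zero; `masked_l1_penalty`, `apply_prune_masks` and `prune_masks` are hypothetical names.

```python
import torch

def masked_l1_penalty(model: torch.nn.Module, prune_masks: dict, strength: float = 1e-4):
    """L1 penalty restricted to the positions marked for pruning.

    prune_masks maps parameter names to tensors that are 1 where the weight
    should be driven to zero and 0 elsewhere.
    """
    params = dict(model.named_parameters())
    penalty = sum(
        (params[name] * mask.to(params[name].dtype)).abs().sum()
        for name, mask in prune_masks.items()
    )
    return strength * penalty

@torch.no_grad()
def apply_prune_masks(model: torch.nn.Module, prune_masks: dict) -> None:
    """Hard-zero the marked positions, e.g. once the penalty has shrunk them."""
    params = dict(model.named_parameters())
    for name, mask in prune_masks.items():
        params[name].masked_fill_(mask.bool(), 0.0)

# Toy usage: mark the top-left 4x4 block of a single 8x8 layer for pruning.
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(8, 8, bias=False))
mask = torch.zeros(8, 8)
mask[:4, :4] = 1.0
prune_masks = {"0.weight": mask}

penalty = masked_l1_penalty(model, prune_masks)  # add this to the task loss
apply_prune_masks(model, prune_masks)            # final hard pruning step
```

The penalty only shrinks the marked weights gradually, so the final hard-zeroing step is still needed to get exact zeros in the blocks.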

Sounds good, please let me know if it worked!

As Paul hit on, zeroing out entire blocks is pretty aggressive.

Additionally, I assume when you say “fine-tune”, you’re referring to applying LoRA vectors to the matrices and training those.

Consider the following:

D = A + b x c

Where A is the original matrix, b and c are vectors, x is the outer product (normally written as an x with a circle around it, ⊗), and D is the final matrix. This is the rank-1 case; a rank-r LoRA adds a sum of r such terms. Because of the outer product, LoRA updates tend to tweak whole rows and columns during fine-tuning, so they are not very granular. They are going to have a very hard time getting any granularity or traction on an entire block you’ve zeroed out in A via this pruning method.

You might have better luck combining a couple of LoRAs fine-tuned for each matrix (i.e. A + b x c + d x e), but this may still prove inferior.
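The point about granularity can be checked numerically with a small sketch. It assumes the adapter is attached to the pruned matrix itself (in the script from the question the adapters target the attention projections, so the pruned MLP matrices are not touched by them at all). The sizes and variable names are made up for illustration.

```python
import torch

torch.manual_seed(0)
A = torch.randn(8, 8, dtype=torch.float64)
A[:4, :4] = 0.0  # a pruned 4x4 block

b, c = torch.randn(8, dtype=torch.float64), torch.randn(8, dtype=torch.float64)
d, e = torch.randn(8, dtype=torch.float64), torch.randn(8, dtype=torch.float64)

D1 = A + torch.outer(b, c)   # one rank-1 update merged into A
D2 = D1 + torch.outer(d, e)  # two rank-1 updates merged into A

# Whatever ends up inside the formerly zero block has rank at most the
# number of outer products that were added.
rank_block_1 = torch.linalg.matrix_rank(D1[:4, :4]).item()
rank_block_2 = torch.linalg.matrix_rank(D2[:4, :4]).item()

# Merging also fills the block with non-zeros, so the sparsity pattern is lost
# unless the mask is applied again afterwards.
block_nonzero = bool((D1[:4, :4] != 0).any())
```

With r=16 and 256x256 or 512x512 blocks, a merged update can therefore contribute at most rank 16 to each block, and the block would no longer be zero for a block-sparse kernel.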

What would be more compatible with LoRA is pruning whole rows and columns. This would complement LoRA fine-tuning much better.
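One way to read "pruning rows and columns" for a gated MLP is to remove whole intermediate neurons: the matching rows of gate_proj and up_proj and the matching columns of down_proj. The sketch below assumes a Llama-style MLP without biases; `prune_mlp_neurons` and the norm-product importance score are hypothetical choices for illustration, not an established recipe.

```python
import torch
from torch import nn
import torch.nn.functional as F

def prune_mlp_neurons(gate: nn.Linear, up: nn.Linear, down: nn.Linear, n_drop: int):
    """Drop the n_drop least important intermediate neurons of a gated MLP."""
    # Hypothetical importance score: product of the weight norms tied to each neuron.
    score = gate.weight.norm(dim=1) * up.weight.norm(dim=1) * down.weight.norm(dim=0)
    n_keep = score.numel() - n_drop
    keep = torch.topk(score, n_keep).indices.sort().values
    new_gate = nn.Linear(gate.in_features, n_keep, bias=False)
    new_up = nn.Linear(up.in_features, n_keep, bias=False)
    new_down = nn.Linear(n_keep, down.out_features, bias=False)
    with torch.no_grad():
        new_gate.weight.copy_(gate.weight[keep])     # rows correspond to neurons
        new_up.weight.copy_(up.weight[keep])
        new_down.weight.copy_(down.weight[:, keep])  # columns correspond to neurons
    return new_gate, new_up, new_down

def mlp(x, gate, up, down):
    # Gated MLP as used in Llama-style models: down(silu(gate(x)) * up(x))
    return down(F.silu(gate(x)) * up(x))

torch.manual_seed(0)
hidden, inter = 8, 32
gate = nn.Linear(hidden, inter, bias=False)
up = nn.Linear(hidden, inter, bias=False)
down = nn.Linear(inter, hidden, bias=False)

small_gate, small_up, small_down = prune_mlp_neurons(gate, up, down, n_drop=8)
out = mlp(torch.randn(2, hidden), small_gate, small_up, small_down)
```

The resulting matrices are simply smaller dense matrices, so they run on standard kernels without any sparse support.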

Alternatively, if you’re going with the square-block pruning method as described, you may need to invent a fine-tuning structure that complements it and perhaps takes advantage of your kernels as well. Applying a learnable 2D conv layer to the matrix might give the granularity needed and be worth exploring.
