Pruning `torch.nn.MultiheadAttention` raises a `RuntimeError`

I am running into the following `RuntimeError` when pruning parameters of the `torch.nn.MultiheadAttention` module:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Here is the minimal code I'm running:

import torch.nn as nn
import torch
import torch.nn.utils.prune as prune

class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.att = nn.MultiheadAttention(
            embed_dim=100,
            num_heads=2,
            batch_first=True
        )
        self.proj = nn.Linear(100, 10)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        attn_output, attn_weights = self.att(x, x, x)
        logits = self.proj(attn_output)
        return logits

    def compute_loss(self, logits, target):
        loss = self.criterion(logits, target)
        return loss

model = Net().cuda()
# prune here
params_to_prune = [
    (model.att, "in_proj_weight"),
    (model.att.out_proj, "weight"),
    (model.proj, "weight")
]
prune.global_unstructured(
    params_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.05,
)

opt = torch.optim.SGD(model.parameters(), lr=0.01)
for step in range(10):
    x = torch.rand((16, 50, 100)).cuda()  # batch_size=16, seq_len=50, embed_dim=100
    y = torch.ones((16 * 50,)).long().cuda()  # one target class per token position
    opt.zero_grad()
    logits = model(x)
    loss = model.compute_loss(logits.reshape(16*50, 10), y)
    loss.backward()
    opt.step()
    print("step {} successful".format(step))

It seems that the error is caused by pruning this particular parameter in `torch.nn.MultiheadAttention`: `(model.att.out_proj, "weight")`. If I remove that entry from `params_to_prune`, the `RuntimeError` goes away. Am I missing anything here for pruning the `MultiheadAttention` module? How should I properly prune the parameters in `out_proj` (which I believe is just a `Linear` layer)? Thanks!
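
In case it is useful, my understanding of what `torch.nn.utils.prune` does under the hood (which may be where the conflict comes from): `global_unstructured` replaces each pruned parameter `weight` with a trainable `weight_orig` plus a `weight_mask` buffer, and registers a forward pre-hook that recomputes `weight = weight_orig * weight_mask` each time that submodule itself is called. A small inspection snippet, assuming the same `model` as above:

# After pruning, out_proj.weight is no longer an nn.Parameter but a plain
# tensor that a forward pre-hook recomputes from weight_orig * weight_mask.
print([name for name, _ in model.att.out_proj.named_parameters()])  # contains 'weight_orig'
print([name for name, _ in model.att.out_proj.named_buffers()])     # contains 'weight_mask'
print(type(model.att.out_proj.weight))                              # torch.Tensor, not nn.Parameter

My suspicion is that `MultiheadAttention.forward` hands `self.out_proj.weight` directly to the functional attention call instead of invoking `out_proj`'s own forward, so the pre-hook never reruns and the second `backward()` walks the stale graph built when the pruning was applied. If that is right, one possible workaround (a sketch, not a verified fix) is to make the pruning of `out_proj` permanent right after the `prune.global_unstructured` call:

# Sketch of a possible workaround: fold the mask into the weight so that
# out_proj.weight becomes an ordinary nn.Parameter again and no pre-hook is
# needed. Caveat: the mask is applied once and not re-enforced during
# training, so pruned entries can become non-zero again as the weights update.
prune.remove(model.att.out_proj, "weight")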

I am running into the same error. Is there any advice on how to prune such a module?