Pruning `torch.nn.MultiheadAttention` causes RuntimeError

I am running into the following RuntimeError when pruning parameters of the `torch.nn.MultiheadAttention` module:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Here is the minimal code I'm running:

import torch.nn as nn
import torch
import torch.nn.utils.prune as prune

class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.att = nn.MultiheadAttention(
            embed_dim=100,
            num_heads=2,
            batch_first=True
        )
        self.proj = nn.Linear(100, 10)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        attn_output, attn_weights = self.att(x, x, x)
        logits = self.proj(attn_output)
        return logits

    def compute_loss(self, logits, target):
        loss = self.criterion(logits, target)
        return loss

model = Net().cuda()
# prune here
params_to_prune = [
    (model.att, "in_proj_weight"),
    (model.att.out_proj, "weight"),
    (model.proj, "weight")
]
prune.global_unstructured(
    params_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.05,
)

opt = torch.optim.SGD(model.parameters(), lr=0.01)
for step in range(0, 10):
    x = torch.rand((16, 50, 100)).cuda()  # batch_size=16, seq_len=50, embed_dim=100
    y = torch.ones((16*50, )).long().cuda()
    opt.zero_grad()
    logits = model(x)
    loss = model.compute_loss(logits.reshape(16*50, 10), y)
    loss.backward()
    opt.step()
    print("step {} successful".format(step))

It seems that the error is caused by pruning this particular parameter of `torch.nn.MultiheadAttention`: `(model.att.out_proj, "weight")`. If I remove that entry from `params_to_prune`, the RuntimeError goes away. Am I missing something here when pruning the MultiheadAttention module? How should I properly prune the parameters of `out_proj` (which I believe is just a Linear layer)?
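My current guess is that `nn.MultiheadAttention.forward` passes `self.out_proj.weight` directly into `F.multi_head_attention_forward` rather than calling `out_proj` itself, so the forward pre-hook that pruning registers on `out_proj` never re-runs; the weight stays attached to the graph built when the mask was first applied (the `weight_orig * weight_mask` multiplication), and the second `backward()` then hits that already-freed graph. Based on that guess, the only workaround I have come up with is to make the pruning permanent right away. This is only a sketch and probably not what I actually want, since the mask is then no longer re-applied while the weights are training:

# Right after the prune.global_unstructured(...) call above:
for module, name in params_to_prune:
    # Make the pruning permanent: drop the <name>_orig parameter and the
    # <name>_mask buffer and turn the pruned tensor back into an ordinary
    # leaf Parameter, so nothing stays attached to an old autograd graph.
    prune.remove(module, name)

With the reparameterization gone, the backward error should no longer occur, but nothing keeps the pruned entries at zero once the optimizer starts updating the weights, so I'd still like to know the proper approach. Thanks!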