How to reverse pruning in PyTorch?

I want to get the l1_unstructured pruning mask and then reverse the pruning. Is this possible in PyTorch?

Do you mean that you only have the weights file and want to reverse the pruning from there?

Let's say I have this implementation:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# A simple linear layer to prune
linear_layer = nn.Linear(10, 5)

# Prune 50% of the weights with the lowest L1 magnitude
prune.l1_unstructured(linear_layer, name="weight", amount=0.5)

# The binary 0/1 mask is stored as a buffer named "weight_mask"
mask = dict(linear_layer.named_buffers())['weight_mask']
print("Pruning mask:")
print(mask)
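As far as I understand, l1_unstructured reparametrizes the layer rather than editing the weights in place: the original tensor moves to a weight_orig parameter, the mask lives in the weight_mask buffer, and weight becomes an ordinary attribute that a forward pre-hook recomputes as weight_orig * weight_mask. That is easy to confirm:

print(dict(linear_layer.named_parameters()).keys())   # includes 'weight_orig' and 'bias'
print(dict(linear_layer.named_buffers()).keys())      # includes 'weight_mask'
print(isinstance(linear_layer.weight, nn.Parameter))  # False: now a recomputed attribute
print(torch.equal(linear_layer.weight, linear_layer.weight_orig * mask))  # True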

Now if I use prune.remove, the pruning becomes permanent: it multiplies the mask with the original weight and registers the result as the new weight parameter.
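For example, on a throwaway copy (the deepcopy is only so the layer above stays untouched):

import copy

removed = copy.deepcopy(linear_layer)
prune.remove(removed, 'weight')  # bakes weight_orig * mask in as the new weight

print('weight_mask' in dict(removed.named_buffers()))  # False: mask and hook are gone
print(isinstance(removed.weight, nn.Parameter))        # True: weight is a real Parameter again
print(torch.equal(removed.weight, linear_layer.weight_orig * mask))  # True: values baked in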

I don't want that. I just want to mask the weights in the forward pass while still being able to update the pruned weights (i.e., the gradients for pruned weights need not be 0).

Ah okay!

class MaskedLinear(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        # Snapshot the pruning mask from the wrapped (still-pruned) layer
        self.register_buffer('mask', dict(self.linear.named_buffers())['weight_mask'])

    def forward(self, x):
        # Apply the mask directly in the forward pass. Multiply weight_orig
        # (the actual trainable Parameter) rather than the cached .weight
        # attribute, whose autograd graph is freed after the first backward.
        masked_weight = self.linear.weight_orig * self.mask
        return nn.functional.linear(x, masked_weight, self.linear.bias)

Is this what you are looking for then?
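One caveat, though: with masked_weight = weight * mask, the chain rule still zeroes the gradient at pruned positions (d(w * m)/dw = m), exactly like PyTorch's own reparametrization. If you literally need nonzero gradients to reach the pruned weights, the usual trick is a straight-through-style mask: the forward pass uses the masked weight, but the gradient bypasses the mask. A rough sketch (the detach trick here is the generic straight-through estimator, not part of torch.nn.utils.prune):

class StraightThroughMaskedLinear(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        self.register_buffer('mask', dict(linear.named_buffers())['weight_mask'])

    def forward(self, x):
        w = self.linear.weight_orig
        # Value: w * mask. Gradient: flows to w as if no mask were applied,
        # because the masking term is detached from the graph.
        masked_weight = w + (w * self.mask - w).detach()
        return nn.functional.linear(x, masked_weight, self.linear.bias)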

Sanity Check Script:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Create a single Linear layer to use in both implementations
linear_layer = nn.Linear(10, 5)

# === Original Implementation ===
prune.l1_unstructured(linear_layer, name="weight", amount=0.5)

mask = dict(linear_layer.named_buffers())['weight_mask']

# === New Implementation ===
class MaskedLinear(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        # Snapshot the pruning mask from the wrapped (still-pruned) layer
        self.register_buffer('mask', dict(self.linear.named_buffers())['weight_mask'])

    def forward(self, x):
        # Apply the mask directly in the forward pass. Multiply weight_orig
        # (the actual trainable Parameter) rather than the cached .weight
        # attribute, whose autograd graph is freed after the first backward.
        masked_weight = self.linear.weight_orig * self.mask
        return nn.functional.linear(x, masked_weight, self.linear.bias)

masked_linear = MaskedLinear(linear_layer)

print("Original Pruning mask:")
print(mask)

print("New Pruning mask:")
print(masked_linear.mask)

# Check if the masks are equal
assert torch.equal(mask, masked_linear.mask)

print("\nOriginal Weights:")
print(linear_layer.weight)

print("\nNew Weights:")
print(masked_linear.linear.weight)

# Check if the weights are equal (masked_linear wraps the same layer object)
assert torch.equal(linear_layer.weight, masked_linear.linear.weight)

# Comparison of forward pass
input_data = torch.randn(1, 10)

print("\n=== Forward Pass Comparison ===")
print("Original output:")
print(linear_layer(input_data))

print("\nNew output:")
print(masked_linear(input_data))

# Check if the outputs are equal
assert torch.equal(linear_layer(input_data), masked_linear(input_data))

# Effect on gradients
print("\n=== Gradient Comparison ===")
original_output = linear_layer(input_data)
original_output.sum().backward()
print("Original gradients:")
# Accessing the gradient of the original weight parameter
print(linear_layer.weight_orig.grad)

# Reset gradients
linear_layer.zero_grad()

masked_output = masked_linear(input_data)
masked_output.sum().backward()
print("\nNew gradients:")
# Accessing the gradient of the original weight parameter in MaskedLinear
print(masked_linear.linear.weight_orig.grad)
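One thing to note when you run this: both gradient printouts are zero at the pruned positions, because the mask scales the gradient as well as the weight. That is expected for the plain w * mask formulation; the straight-through variant above is what you would reach for if you need those gradients to be nonzero. A quick check:

# Expected: every pruned position gets exactly zero gradient (d(w * m)/dw = m)
assert torch.all(masked_linear.linear.weight_orig.grad[mask == 0] == 0)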