I want to get the l1_unstructured pruning mask and then reverse the pruning. Is it possible to do in PyTorch?
Do you mean that you only have the weights file and you want to reverse the pruning from there?
Let's say I have this implementation:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
linear_layer = nn.Linear(10, 5)
# Prune 50% of the weights with the smallest L1 magnitude
prune.l1_unstructured(linear_layer, name="weight", amount=0.5)
# l1_unstructured registers the binary mask as a 'weight_mask' buffer
mask = dict(linear_layer.named_buffers())['weight_mask']
print("Pruning mask:")
print(mask)
Now if I use prune.remove, it makes the pruning permanent (it multiplies the mask with the original weight and sets the result as the new weight parameter).
I don't want that. I just want to mask the weights in the forward pass but still update the pruned weights too (i.e., the gradients for the pruned weights need not be 0).
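For reference, this is roughly what prune.remove leaves behind (a quick check; weight_orig and weight_mask are the names torch.nn.utils.prune registers):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

linear_layer = nn.Linear(10, 5)
prune.l1_unstructured(linear_layer, name="weight", amount=0.5)
# While pruning is active, the module carries a reparametrization:
# 'weight_orig' (trainable Parameter) and 'weight_mask' (buffer)
print(hasattr(linear_layer, 'weight_orig'))  # True
print(hasattr(linear_layer, 'weight_mask'))  # True

prune.remove(linear_layer, "weight")
# prune.remove bakes weight_orig * weight_mask into a plain 'weight'
# Parameter and deletes the reparametrization, so the mask is gone for good
print(hasattr(linear_layer, 'weight_orig'))  # False
print(hasattr(linear_layer, 'weight_mask'))  # False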
Ah okay!
class MaskedLinear(nn.Module):
    def __init__(self, linear):
        super(MaskedLinear, self).__init__()
        self.linear = linear
        # Grab the mask that l1_unstructured registered on the wrapped layer
        # (assumes pruning has already been applied to `linear`)
        self.register_buffer('mask', dict(self.linear.named_buffers())['weight_mask'])

    def forward(self, x):
        # Apply the mask directly in the forward pass. Read the live
        # 'weight_orig' parameter rather than 'weight' (which the pruning
        # hook leaves as a stale, precomputed tensor), so optimizer updates
        # are picked up on every forward; fall back to 'weight' if the
        # pruning reparametrization has been removed.
        weight = getattr(self.linear, 'weight_orig', self.linear.weight)
        masked_weight = weight * self.mask
        return nn.functional.linear(x, masked_weight, self.linear.bias)
Is this what you are looking for then?
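Training with it would then look something like this (dummy data and a plain SGD optimizer, just as a sketch; `linear_layer` is your already-pruned layer from above):

masked_linear = MaskedLinear(linear_layer)
optimizer = torch.optim.SGD(masked_linear.parameters(), lr=0.1)

for _ in range(3):
    x = torch.randn(8, 10)      # dummy batch
    target = torch.randn(8, 5)  # dummy targets
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(masked_linear(x), target)
    loss.backward()
    # Moves weight_orig and bias; note the pruned entries get zero grad
    # through the mask, so plain masking alone leaves them unchanged
    optimizer.step()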
Sanity Check Script:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
# Create a single Linear layer to use in both implementations
linear_layer = nn.Linear(10, 5)
# === Original Implementation ===
prune.l1_unstructured(linear_layer, name="weight", amount=0.5)
mask = dict(linear_layer.named_buffers())['weight_mask']
# === New Implementation ===
class MaskedLinear(nn.Module):
    def __init__(self, linear):
        super(MaskedLinear, self).__init__()
        self.linear = linear
        # Grab the mask that l1_unstructured registered on the wrapped layer
        self.register_buffer('mask', dict(self.linear.named_buffers())['weight_mask'])

    def forward(self, x):
        # Apply the mask directly in the forward pass; read the live
        # 'weight_orig' parameter rather than the stale 'weight' attribute,
        # falling back to 'weight' if the reparametrization was removed
        weight = getattr(self.linear, 'weight_orig', self.linear.weight)
        masked_weight = weight * self.mask
        return nn.functional.linear(x, masked_weight, self.linear.bias)
masked_linear = MaskedLinear(linear_layer)
print("Original Pruning mask:")
print(mask)
print("New Pruning mask:")
print(masked_linear.mask)
# Check if the masks are equal
assert torch.equal(mask, masked_linear.mask)
print("\nOriginal Weights:")
print(linear_layer.weight)
print("\nNew Weights:")
print(masked_linear.linear.weight)
# Check if the weights are equal
assert torch.equal(linear_layer.weight, masked_linear.linear.weight)
# Comparison of forward pass
input_data = torch.randn(1, 10)
print("\n=== Forward Pass Comparison ===")
print("Original output:")
print(linear_layer(input_data))
print("\nNew output:")
print(masked_linear(input_data))
# Check if the outputs are equal
assert torch.equal(linear_layer(input_data), masked_linear(input_data))
# Effect on gradients
print("\n=== Gradient Comparison ===")
original_output = linear_layer(input_data)
original_output.sum().backward(retain_graph=True)
print("Original gradients:")
# Accessing the gradient of the original weight parameter
print(linear_layer.weight_orig.grad)
# Reset gradients
linear_layer.zero_grad()
masked_output = masked_linear(input_data)
masked_output.sum().backward()
print("\nNew gradients:")
# Accessing the gradient of the original weight parameter in MaskedLinear
print(masked_linear.linear.weight_orig.grad)
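One caveat on the gradient requirement: with masked_weight = weight * mask, the chain rule multiplies the upstream gradient by the mask, so the pruned entries always receive zero gradient and never move. If you strictly need non-zero gradients for the pruned weights, a straight-through-estimator style mask is one option. This is my own sketch (the class name is made up, and it's not something torch.nn.utils.prune provides); it assumes the wrapped layer is a plain nn.Linear, e.g. after prune.remove, so that weight is a regular Parameter:

import torch
import torch.nn as nn

class StraightThroughMaskedLinear(nn.Module):
    def __init__(self, linear, mask):
        super().__init__()
        self.linear = linear                # plain nn.Linear, no pruning hooks
        self.register_buffer('mask', mask)  # fixed 0/1 mask, not trained

    def forward(self, x):
        w = self.linear.weight
        # Forward value is w * mask, but the detached correction term means
        # the backward pass sees d(masked_w)/dw = 1 for *every* entry,
        # so even the pruned weights receive full (non-zero) gradients
        masked_weight = w + (w * self.mask - w).detach()
        return nn.functional.linear(x, masked_weight, self.linear.bias)

Hooked up to your example, that would be:

mask = dict(linear_layer.named_buffers())['weight_mask'].clone()
prune.remove(linear_layer, 'weight')  # bake in pruning; 'weight' is a plain Parameter again
st_linear = StraightThroughMaskedLinear(linear_layer, mask)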