Condensing a pruned sparse network

Hey,

I am trying to prune a network and then condense the sparse network so that the pruning actually results in a speedup. As a toy example I am using a model with a single linear layer, pruning it, and then using the pruned weights to build the new network.

To generate the new network I take a deep copy of the original network and condense it by replacing the linear layer with a sparse layer. Although this approach seems to work for the toy example, I am wondering whether there is a simpler way to do it.

At the moment the only way I could think of implementing the sparse linear layer was to split the original linear layer into one small linear layer per output row and then run them one by one in the forward pass; it would be better if they could all run in parallel (I have sketched one idea for that at the end of this post).

Any advice on improving this would be much appreciated.

Here is the code I am using for my toy example:

#!/usr/bin/python

import copy
import torch
import numpy as np

torch.manual_seed(0)

class Original(torch.nn.Module):

    def __init__(self):
        super(Original, self).__init__()
        self.l1 = torch.nn.Linear(3, 3, bias=True)
    
    def forward(self, x):
        return self.l1(x)

class SparseLinear(torch.nn.Module):

    def __init__(self, original):
        super(SparseLinear, self).__init__()
        # one boolean mask row per output neuron
        nonzero_weight = (original.weight != 0)
        needs_bias = original.bias is not None
        self.linears = torch.nn.ModuleList()
        for i, weight in enumerate(nonzero_weight):
            # column indices of the surviving weights in this row;
            # flatten (not squeeze) so a row with a single survivor stays 1-D
            capture_indices = weight.nonzero().flatten()
            l = torch.nn.Linear(int(weight.sum()), 1, bias=needs_bias)
            # keep nn.Linear's (out_features, in_features) weight shape
            l.weight.data = original.weight.data[i, capture_indices].view(1, -1)
            if needs_bias:
                l.bias.data = original.bias.data[i].view(1)
            l.register_buffer('weight_mask', weight)
            l.register_buffer('capture_indices', capture_indices)
            self.linears.append(l)

    def forward(self, x):
        # NOTE: only supports a single (unbatched) 1-D input
        y = []
        for linear in self.linears:
            # gather just the inputs this row's surviving weights touch
            _x = x[linear.capture_indices].view(-1)
            _y = linear(_x).squeeze()  # scalar output for this row
            y += [_y]
        return torch.stack(y)

def prune(model):
    # magnitude pruning: zero out the k% smallest weights by absolute value
    k = 20
    all_weights = []
    for p in model.parameters():
        if p.dim() != 1:  # skip biases
            all_weights += list(p.data.abs().cpu().numpy().flatten())
    threshold = np.percentile(np.array(all_weights), k)

    for p in model.parameters():
        if p.dim() != 1:
            mask = (p.data.abs() > threshold).float()
            p.data *= mask

def condense(model):
    # replace every dense Linear child with its condensed SparseLinear
    # (only handles direct children; nested modules would need recursion)
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            setattr(model, name, SparseLinear(module))

def countparams(model):
    total = sum(p.numel() for p in model.parameters())
    withgrad = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters:: total: {total}, withgrad: {withgrad}")

def print_state_dict(model):
    dic = model.state_dict()
    print("\tstate_dict:: ", dic)
    print()

original = Original()
original.eval()
x = torch.FloatTensor([1, 2, 3])

# print_state_dict(original)
countparams(original)
print("y::", original(x))

prune(original)
# print_state_dict(original)
countparams(original)
print("y::", original(x))

condensed = copy.deepcopy(original)
condense(condensed)
condensed.eval()

# print_state_dict(condensed)
countparams(condensed)
print("y::", condensed(x))

This results in the following output:

Parameters:: total: 9, withgrad: 9
y:: tensor([-0.8104, -0.4052, 0.7504], grad_fn=<SqueezeBackward3>)
Parameters:: total: 9, withgrad: 9
y:: tensor([-0.8061, -0.4052, 0.7618], grad_fn=<SqueezeBackward3>)
Parameters:: total: 7, withgrad: 7
y:: tensor([-0.8061, -0.4052, 0.7618], grad_fn=<StackBackward>)
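
For reference, the direction I was considering for the parallel version is to pack all of the surviving weights into a single sparse COO tensor and do one sparse-dense matmul per forward pass instead of looping over rows. A rough sketch (the class name SparseLinearCOO is just mine, and I am assuming torch.sparse_coo_tensor / torch.sparse.mm support this pattern; I have not benchmarked it against the loop):

class SparseLinearCOO(torch.nn.Module):

    def __init__(self, original):
        super(SparseLinearCOO, self).__init__()
        w = original.weight.data
        self.out_features, self.in_features = w.shape
        # (2, nnz) row/col indices plus the matching values, both taken
        # in row-major order so they stay aligned
        self.register_buffer('indices', w.nonzero().t())
        self.values = torch.nn.Parameter(w[w != 0])
        self.bias = (torch.nn.Parameter(original.bias.data.clone())
                     if original.bias is not None else None)

    def forward(self, x):
        # rebuild the sparse matrix from the trainable values and do a
        # single sparse @ dense product for all output rows at once
        w = torch.sparse_coo_tensor(
            self.indices, self.values,
            (self.out_features, self.in_features))
        y = torch.sparse.mm(w, x.view(-1, 1)).view(-1)
        return y if self.bias is None else y + self.bias

Like the loop version this only handles a single unbatched input, and I do not know at what sparsity level the sparse matmul actually beats a dense one, which is part of what I am asking.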