I’m looking for a method to sparsify a simple network as described below:
model = torch.nn.Sequential(collections.OrderedDict([
    ("layer1", torch.nn.Linear(num_A, num_A)),
    ("act1", torch.nn.Tanh()),
    ("layer2", torch.nn.Linear(num_A, num_B)),
    ("act2", torch.nn.Tanh()),
    ("layer3", torch.nn.Linear(num_B, num_B)),
]))
I am using torch.nn.utils.prune.custom_from_mask to zero out the weights I want pruned, by sending a mask matrix to the device that is 99% zeros and 1% ones:
matrix = matrix.to(device)
module1 = model.layer1
module1 = module1.to(device)
torch.nn.utils.prune.custom_from_mask(module1, name='weight', mask=matrix)
The model size after masking is still as large as the fully connected network. The results I’m getting are good for this model and strategy, but I need the network to be genuinely sparse to deploy it at scale: materializing the dense mask hits memory limits on both the CPU and GPU, and building the layer fully connected hits memory limits on the GPU.
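For context, here is a minimal reproduction of the size issue (num_A and num_B are placeholder dimensions). As I understand it, custom_from_mask reparametrizes the layer with a dense weight_orig parameter plus a dense weight_mask buffer, so the pruned model actually stores more than the original:

```python
import collections
import torch
import torch.nn.utils.prune as prune

num_A, num_B = 100, 50  # placeholder sizes

model = torch.nn.Sequential(collections.OrderedDict([
    ("layer1", torch.nn.Linear(num_A, num_A)),
    ("act1", torch.nn.Tanh()),
    ("layer2", torch.nn.Linear(num_A, num_B)),
    ("act2", torch.nn.Tanh()),
    ("layer3", torch.nn.Linear(num_B, num_B)),
]))

# Mask with roughly 1% ones, 99% zeros.
mask = (torch.rand(num_A, num_A) < 0.01).float()
prune.custom_from_mask(model.layer1, name="weight", mask=mask)

# The pruning reparametrization keeps a dense 'weight_orig' parameter
# and a dense 'weight_mask' buffer, so nothing shrinks.
print([n for n, _ in model.layer1.named_parameters()])  # includes 'weight_orig'
print([n for n, _ in model.layer1.named_buffers()])     # includes 'weight_mask'
```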
Any suggestions would be appreciated.