I’m looking for a method to sparsify a simple network as described below:
```python
import collections

import torch

# num_A and num_B are the layer sizes, defined elsewhere
model = torch.nn.Sequential(
    collections.OrderedDict(
        [
            ("layer1", torch.nn.Linear(num_A, num_A)),
            ("act1", torch.nn.Tanh()),
            ("layer2", torch.nn.Linear(num_A, num_B)),
            ("act2", torch.nn.Tanh()),
            ("layer3", torch.nn.Linear(num_B, num_B)),
        ]
    )
)
```
I am using `torch.nn.utils.prune.custom_from_mask` to prune the weights I want to be zero, passing it a mask matrix (sent to the device) that is 99% zeros and 1% ones:
```python
# matrix: 0/1 mask with the same shape as model.layer1.weight (99% zeros, 1% ones)
matrix = matrix.to(device)
module1 = model.layer1
module1 = module1.to(device)
torch.nn.utils.prune.custom_from_mask(module1, name='weight', mask=matrix)
```
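A small check of what `custom_from_mask` actually stores may explain the size. This is a sketch on a hypothetical 100x100 `Linear` standing in for `layer1`, with a random ~1% mask rather than my real one. As far as I can tell, pruning re-parametrizes the layer: it keeps the dense original weight as `weight_orig`, the dense mask as the buffer `weight_mask`, and recomputes `weight` as their product, so nothing gets smaller.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = torch.nn.Linear(100, 100)  # hypothetical small stand-in for model.layer1

# Hypothetical mask with roughly 1% ones, same shape as the weight.
mask = (torch.rand_like(layer.weight) < 0.01).float()
prune.custom_from_mask(layer, name="weight", mask=mask)

# The dense original weight is kept as a parameter ...
param_names = sorted(name for name, _ in layer.named_parameters())
# ... the dense mask is kept as a buffer ...
buffer_names = sorted(name for name, _ in layer.named_buffers())
# ... and `weight` is recomputed as weight_orig * weight_mask (also a dense tensor).
print(param_names, buffer_names)
print(layer.weight_orig.numel(), layer.weight_mask.numel(), layer.weight.numel())
```

Calling `prune.remove(layer, "weight")` afterwards makes the pruning permanent, but it writes the zeros into an ordinary dense `weight` parameter, so the storage stays dense either way.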
After masking, the model is still as large as a fully connected network. I believe the results I'm getting are good for this model and strategy, but I need the network to actually be sparse to deploy it at scale: I hit memory limits both when making the mask matrix dense (on the CPU and GPU) and when building the layer fully connected (on the GPU).
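One direction I am considering, written as an untested sketch rather than something from my current setup: a layer that stores only the kept weights as `(row, col)` indices plus a trainable `values` vector, and builds a sparse COO weight with `torch.sparse_coo_tensor` for `torch.sparse.mm` in the forward pass. `SparseLinear` and `random_indices` are hypothetical names I made up; `random_indices` samples the connectivity directly, so no dense mask is ever built. It assumes 2-D input of shape `(batch, in_features)`.

```python
import math

import torch


def random_indices(out_features, in_features, density):
    """Hypothetical helper: sample about density * out * in unique (row, col)
    pairs without materializing a dense out x in mask."""
    n = max(1, int(density * out_features * in_features))
    # Random flat positions; torch.unique drops duplicates (so nnz <= n) and sorts.
    flat = torch.unique(torch.randint(0, out_features * in_features, (n,)))
    rows = torch.div(flat, in_features, rounding_mode="floor")
    cols = flat % in_features
    return torch.stack([rows, cols])  # shape (2, nnz)


class SparseLinear(torch.nn.Module):
    """Hypothetical Linear replacement that stores only the kept weights."""

    def __init__(self, in_features, out_features, indices):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Fixed connectivity as a buffer, so .to(device) moves it with the module.
        self.register_buffer("indices", indices.long())
        # Same uniform range nn.Linear uses by default: +-1/sqrt(in_features).
        bound = 1.0 / math.sqrt(in_features)
        self.values = torch.nn.Parameter(
            torch.empty(indices.shape[1]).uniform_(-bound, bound)
        )
        self.bias = torch.nn.Parameter(
            torch.empty(out_features).uniform_(-bound, bound)
        )

    def forward(self, x):
        # x: (batch, in_features). Rebuild the sparse COO weight from the values.
        weight = torch.sparse_coo_tensor(
            self.indices, self.values, (self.out_features, self.in_features)
        )
        # (out, in) @ (in, batch) -> (out, batch), transposed back to (batch, out).
        return torch.sparse.mm(weight, x.t()).t() + self.bias
```

The idea would be to swap in something like `SparseLinear(num_A, num_A, random_indices(num_A, num_A, 0.01))` for `layer1`. Storage is two int64 indices plus one float32 value per kept weight (about 20 bytes per nonzero versus 4 bytes per entry for a dense float32 weight), so at 1% density that is roughly 5% of the dense size. Whether sparse matmul is also fast enough on the GPU at this density is a separate question I would still need to benchmark.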
Any suggestions would be appreciated.