Hi all,
This is my first question here, so I will try not to be too clumsy. I am also more or less a beginner with PyTorch, so I hope my questions are not too obvious. Thanks in advance for all your help!
I just trained a simple model with 3 linear layers and everything went smoothly. Now I want to replicate this model so that I can run adversarial training on a few input samples at the same time, and for that I will be using the hack Goodfellow proposed here.
The problem is that the model may end up being quite big. But since many of the weights (all the off-diagonal blocks) are going to be equal to zero, I thought: ok, let’s prune them, I don’t need them for anything. Then the size of the repeated model would scale linearly with the number of replicas, not quadratically. However, I just put together a small example and this does not seem to work as I expected:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Original small model
enc_small = torch.nn.Sequential()
enc_small.add_module("dense_0", nn.Linear(2000, 4000))
torch.save(enc_small, 'orig_model.pth')
n_rep = 3
# Repeated model
enc_rep = torch.nn.Sequential()
enc_rep.add_module("dense_0", nn.Linear(2000 * n_rep, 4000 * n_rep))
torch.save(enc_rep, 'model_repeated.pth')
# Repeated model with pruning
enc_pruned = torch.nn.Sequential()
enc_pruned.add_module("dense_0", nn.Linear(2000 * n_rep, 4000 * n_rep))
# The mask keeps only the n_rep diagonal blocks and zeroes out everything else.
# Each block is 4000 x 2000 because nn.Linear stores weight as (out_features, in_features).
mask = np.zeros(enc_pruned.dense_0.weight.shape).astype(int)
for ir in range(n_rep):
    mask[(ir * 4000):((ir + 1) * 4000),
         (ir * 2000):((ir + 1) * 2000)] = 1
prune.custom_from_mask(enc_pruned.dense_0, name='weight', mask=torch.tensor(mask))
torch.save(enc_pruned, 'pruned_model.pth')
And here is how the models look once saved:
25629256 -rw-r--r-- 1 users 1.4G Aug 16 09:47 pruned_model.pth
25629254 -rw-r--r-- 1 users 275M Aug 16 09:46 model_repeated.pth
25629253 -rw-r--r-- 1 users 31M Aug 16 09:46 orig_model.pth
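As a sanity check on the listing above, here is a rough back-of-the-envelope sketch of the sizes I would expect from the parameter counts alone. It assumes float32 storage (4 bytes per value, the PyTorch default) and ignores any pickling overhead; linear_bytes is just a helper name I made up for this estimate.

```python
def linear_bytes(n_in, n_out, bytes_per_value=4):
    """Approximate storage of one dense linear layer: weight matrix plus bias."""
    return (n_in * n_out + n_out) * bytes_per_value

MIB = 1024 ** 2
n_rep = 3

orig = linear_bytes(2000, 4000)                      # single small layer
repeated = linear_bytes(2000 * n_rep, 4000 * n_rep)  # dense repeated layer
block_diag = n_rep * orig                            # only the diagonal blocks

print(f"original:            {orig / MIB:.1f} MiB")
print(f"repeated (dense):    {repeated / MIB:.1f} MiB")
print(f"diagonal blocks only: {block_diag / MIB:.1f} MiB")
```

The first two estimates are in line with the 31M and 275M files, so the size I was hoping for with the pruned model is roughly the third number, not 1.4G.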
I believe I am missing something obvious, but I don’t know what… It might be the behaviour of torch.nn.utils.prune.custom_from_mask that I am not understanding. I would really appreciate any help you could offer. Thank you so much!
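To make the question about custom_from_mask more concrete, here is a minimal sketch that inspects what the call leaves on a module. It uses a tiny nn.Linear(4, 8) and an arbitrary mask purely for illustration (these sizes are not the ones from my model), and it only counts stored values; it does not try to explain the exact file size.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Tiny layer so the example runs instantly; sizes are illustrative only.
layer = nn.Linear(4, 8)
mask = torch.zeros_like(layer.weight)
mask[:4, :2] = 1  # keep one arbitrary block, zero out the rest

prune.custom_from_mask(layer, name="weight", mask=mask)

# After pruning, the original values and the mask are both kept on the module.
param_names = [name for name, _ in layer.named_parameters()]
buffer_names = [name for name, _ in layer.named_buffers()]
print("parameters:", param_names)
print("buffers:   ", buffer_names)

# Count the values held in the state dict before and after making pruning permanent.
n_before = sum(t.numel() for t in layer.state_dict().values())
prune.remove(layer, "weight")
n_after = sum(t.numel() for t in layer.state_dict().values())
print("stored values before remove:", n_before)
print("stored values after remove: ", n_after)
print("weight is still dense:", tuple(layer.weight.shape))
```

If I read this correctly, the pruned module carries a dense weight_orig and a dense weight_mask of the same shape, and even after prune.remove the weight is a dense tensor with zeros in it, so nothing here reduces the number of stored values.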
P.S.: In parallel, I am experimenting with assigning the parameters as sparse matrices, but so far to no avail…
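For reference, this is roughly the kind of sparse experiment I mean, as a small sketch and not a working solution. It assumes a tiny layer (4 inputs, 8 outputs) instead of the real sizes, builds the block-diagonal weight with torch.block_diag, converts it with to_sparse(), and only checks the forward pass with torch.sparse.mm. It says nothing about whether such a weight can be registered as a parameter and trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_rep, n_in, n_out = 3, 4, 8  # illustrative sizes, not the real model
small = nn.Linear(n_in, n_out)

# Block-diagonal weight: the small weight repeated n_rep times on the diagonal.
w_dense = torch.block_diag(*[small.weight.detach()] * n_rep)
b_rep = small.bias.detach().repeat(n_rep)
w_sparse = w_dense.to_sparse()  # COO tensor holding only the non-zero entries

x = torch.randn(5, n_in * n_rep)  # batch of 5 concatenated inputs

with torch.no_grad():
    # nn.Linear computes x @ W.T + b; with a sparse W this is (W @ x.T).T + b.
    y_sparse = torch.sparse.mm(w_sparse, x.t()).t() + b_rep
    y_dense = x @ w_dense.t() + b_rep
    # Reference: apply the small layer to each replica's slice of the input.
    y_chunks = torch.cat([small(c) for c in x.chunk(n_rep, dim=1)], dim=1)

print("dense entries: ", w_dense.numel())
print("stored entries:", w_sparse._nnz())
print("max abs diff:  ", (y_sparse - y_dense).abs().max().item())
```

In this toy setting the sparse forward pass agrees with the dense one and with applying the small layer to each slice separately, while storing only the diagonal blocks.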