I am trying to implement a global, unstructured pruning algorithm closely based on "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks" (Frankle et al.) and "Comparing Rewinding and Fine-tuning in Neural Network Pruning" (Renda et al.). I am aware that PyTorch already provides pruning utilities in torch.nn.utils.prune; my motivation for implementing things from scratch is learning and research.
I use a toy LeNet-300-100 dense network trained on MNIST for this task. The model class is:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet300(nn.Module):
    def __init__(self):
        super().__init__()

        # Define layers-
        self.fc1 = nn.Linear(in_features = 28 * 28 * 1, out_features = 300)
        self.fc2 = nn.Linear(in_features = 300, out_features = 100)
        self.output = nn.Linear(in_features = 100, out_features = 10)
        # self.weights_initialization()

    def forward(self, x):
        out = F.leaky_relu(self.fc1(x))
        out = F.leaky_relu(self.fc2(out))
        return self.output(out)
The function to train for one epoch is:
import numpy as np
from tqdm import tqdm

def train_model_progress(model, train_loader, train_dataset):
    '''
    Perform one epoch of training using 'train_loader'.
    Returns the average loss and accuracy (%) for this epoch.
    '''
    running_loss = 0.0
    running_corrects = 0.0

    model.train()

    with tqdm(train_loader, unit = 'batch') as tepoch:
        for images, labels in tepoch:
            tepoch.set_description("Training: ")

            # images = images.view(-1, 28 * 28 * 1)
            images = images.to(device)
            labels = labels.to(device)

            # Get model predictions-
            outputs = model(images)

            # Compute loss-
            J = loss(outputs, labels)

            # Empty accumulated gradients-
            optimizer.zero_grad()

            # Perform backprop-
            J.backward()

            # Remove/zero-out all gradients corresponding to pruned connections-
            # for name, param in model.named_parameters():
            for param in model.parameters():
                '''
                if 'mask' in name:
                    continue
                '''
                tensor = param.data.cpu().numpy()
                grad_tensor = param.grad.data.cpu().numpy()
                grad_tensor = np.where(tensor == 0., 0., grad_tensor)
                param.grad.data = torch.from_numpy(grad_tensor).to(device)

            # Update parameters-
            optimizer.step()

            '''
            global step
            optimizer.param_groups[0]['lr'] = custom_lr_scheduler.get_lr(step)
            step += 1
            '''

            # Compute model's performance statistics-
            running_loss += J.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            running_corrects += torch.sum(predicted == labels.data)

            tepoch.set_postfix(
                loss = running_loss / len(train_dataset),
                accuracy = (running_corrects.double().cpu().numpy() / len(train_dataset)) * 100
            )

    train_loss = running_loss / len(train_dataset)
    train_acc = (running_corrects.double() / len(train_dataset)) * 100

    return train_loss, train_acc.cpu().numpy()
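For completeness, the function above refers to a few globals (device, loss, optimizer). My actual setup is more involved, but a minimal sketch looks like this (the optimizer choice and learning rate here are placeholders):

# Minimal sketch of the globals used by train_model_progress();
# the optimizer and learning rate below are placeholders.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LeNet300().to(device)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2)  # placeholder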
The focus is on these critical lines:
for param in model.parameters():
    '''
    if 'mask' in name:
        continue
    '''
    tensor = param.data
    # Or: tensor = param.data.cpu().numpy()
    grad_tensor = param.grad.data
    # Or: grad_tensor = param.grad.data.cpu().numpy()
    grad_tensor = torch.where(tensor == 0., 0., grad_tensor)
    # Or: grad_tensor = np.where(tensor == 0., 0., grad_tensor)
    param.grad.data = grad_tensor.to(device)
    # Or: param.grad.data = torch.from_numpy(grad_tensor).to(device)
The idea is simple: iterate through the model's trainable parameters and copy each parameter and its computed gradient into tensor and grad_tensor. Then assign 0 to every position in grad_tensor where the corresponding weight is pruned (i.e., equals 0), so that pruned connections receive no gradient update.
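To illustrate the intended effect in isolation, here is a tiny standalone example of that masking step (the toy tensors are just for the demo, not from the real model):

# Toy demonstration of masking gradients at pruned (zero) positions.
import torch

w = torch.tensor([0.5, 0.0, -1.2, 0.0], requires_grad = True)  # positions 1 and 3 are "pruned"
out = (w * torch.tensor([1.0, 2.0, 3.0, 4.0])).sum()
out.backward()

print(w.grad)  # tensor([1., 2., 3., 4.])

# Zero the gradient wherever the weight itself is zero-
w.grad.data = torch.where(w.data == 0., torch.zeros_like(w.grad), w.grad.data)
print(w.grad)  # tensor([1., 0., 3., 0.])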
The problem is that this doesn't behave as expected: sometimes the sparsity is preserved, but most of the time it is not. For example, an unpruned LeNet-300-100 has 266610 non-zero parameters, and after 20% global pruning, 213288 parameters survive. But when I retrain the pruned model, it ends up with 266610 non-zero parameters again instead of 213288, which destroys the pruning.
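For reference, this is how I count the surviving (non-zero) parameters before and after retraining; count_nonzero_params is just a small helper of mine:

# Count the surviving (non-zero) parameters of a model.
import torch

def count_nonzero_params(model):
    return int(sum(torch.count_nonzero(p.data) for p in model.parameters()))

print(count_nonzero_params(model))
# 266610 for the unpruned LeNet-300-100,
# 213288 right after 20% global pruning,
# and (unexpectedly) 266610 again after retraining.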
What am I missing?