Unexpected Gradient Presence in Zero-valued Positions of Trainable Parameters after Value Replacement during Training

Hi there.

I tried to modify the values of trainable parameters during the training process and use them for the next epoch.

However, I observed that after changing the parameters, the zeros in the parameters are also updated during the next batch update, which goes against intuition because multiplying any element by zero should not produce any gradient.

I tried to implement the minimum runnable code and reproduced my problem. The code can be run directly:

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    """Toy model: a fixed 5x5 adjacency matrix multiplied elementwise by a trainable mask."""

    def __init__(self):
        super().__init__()
        # Fixed adjacency (all 0.5); registered as a parameter but never trained.
        self.adj = nn.Parameter(torch.Tensor(5, 5).fill_(0.5), requires_grad=False)
        # Trainable elementwise mask, initialized to all ones.
        self.adj_mask = nn.Parameter(torch.ones(5, 5), requires_grad=True)

    def forward(self):
        # Hadamard product: mask the adjacency with the trainable mask.
        return self.adj * self.adj_mask

def print_tensor_info(tensor):
    """Print a short numeric summary of *tensor*: zero/non-zero counts, extrema, NaN flag."""
    zero_count = int((tensor == 0).sum())
    print(f"Zero values: {zero_count}")
    print(f"Non-zero values: {tensor.numel() - zero_count}")
    print(f"Max value: {tensor.max().item()}")
    print(f"Min value: {tensor.min().item()}")
    print(f"Has NaN: {torch.isnan(tensor).any().item()}")

# prune the mask
def prune_mask(mask, percentage):
    with torch.no_grad():
        # print("before prune")
        # print_tensor_info(mask)
        adj_mask_tensor = mask.flatten()
        # print("before_adj_mask_tensor.shape: ", adj_mask_tensor.shape)
        nonzero = torch.abs(adj_mask_tensor) > 0
        adj_mask = adj_mask_tensor[nonzero]  # 13264 - 2708
        # print("before_adj_mask.shape: ", adj_mask.shape)
        # print(adj_mask)
        adj_total = adj_mask.shape[0]
        adj_y, adj_i = torch.sort(adj_mask.abs())
        adj_thre_index = int(adj_total * percentage)
        adj_thre = adj_y[adj_thre_index]
        # print("adj_thre", adj_thre)
        abs_values = torch.abs(mask)
        index = abs_values >= adj_thre
        mask.data[index] = 1
        mask.data[~index] = 0
        # print("After prune")
        # print_tensor_info(mask)
    # print("-------")

model = SimpleModel()

# Optimize only the trainable mask; the deliberately huge lr makes updates visible.
optimizer = optim.Adam([model.adj_mask], lr=10)

target = torch.randn(5, 5)

pruneflag = False
for epoch in range(41):

    masked_adj = model()
    loss = nn.MSELoss()(masked_adj, target)

    # Standard training step. NOTE(review): the pasted snippet dropped these
    # three calls and contained an IndentationError in a commented-out debug
    # block, so it could not run as posted; restored here.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # One epoch after each pruning, inspect how the mask changed.
    if pruneflag:
        print("Check adj_mask after backpropagation")
        print_tensor_info(model.adj_mask)
        pruneflag = False

    # prune every 20 epoch
    if epoch % 20 == 0:
        print("----epoch" + str(epoch))
        prune_mask(model.adj_mask, 0.2)
        pruneflag = True

My questions are:
(1) After modifying the adj_mask tensor, during the backpropagation, I observe that the gradients on the adj_mask are not updating as intended—elements multiplied by zero through a Hadamard product are unexpectedly receiving gradients, which results in incorrect updates.

(2) Furthermore, despite manually setting the gradients to zero, the adj_mask continues to get updated following optimizer.step() (see the commented-out parts of my code). I hypothesize that this could be related to the momentum terms inherent in the Adam optimizer, because the issue does not occur when I switch to SGD. Nonetheless, this hypothesis doesn't clarify the underlying reason for the first question.

Thanks for reading. Looking forward to your help