Why do manually modified parameters revert to their original values after optimizer.step() during training?


During training, I try to manually modify some parameters under a special condition, but they automatically revert to their previous values. How can I prevent this from happening in PyTorch?

The code:

    def _train_epoch(self, train_data, aux_train_data, epoch_idx, show_progress=False):
        r"""Train the model in an epoch

        Args:
            train_data (DataLoader): The train data.
            aux_train_data (DataLoader): The auxiliary train data, cycled through alongside ``train_data``.
            epoch_idx (int): The current epoch id.
            show_progress (bool): Show the progress of training epoch. Defaults to ``False``.

        Returns:
            float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains
            multiple parts and the model return these multiple parts loss instead of the sum of loss, it will return a
            tuple which includes the sum of loss in each part.
        """
        total_loss = None
        iter_data = (
            tqdm(
                train_data,
                total=len(train_data),
                desc=set_color(f"Train {epoch_idx:>5}", "pink"),
            )
            if show_progress
            else train_data
        )
        if not self.config["single_spec"] and train_data.shuffle:
            train_data.sampler.set_epoch(epoch_idx)

        scaler = amp.GradScaler(enabled=self.enable_scaler)
        batchcount = 0

        # Restart aux_train_data whenever it is exhausted, so every main batch gets an aux batch.
        dataloader_iterator = iter(aux_train_data)

        # for interaction, interaction_aux in zip(iter_data, aux_train_data):
        for batch_index, interaction in enumerate(iter_data):
            try:
                interaction_aux = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(aux_train_data)
                interaction_aux = next(dataloader_iterator)

            interaction = interaction.to(self.device)
            interaction_aux = interaction_aux.to(self.device)
            self.optimizer.zero_grad()
            sync_loss = 0
            if not self.config["single_spec"]:
                sync_loss = self.sync_grad_loss()
            with torch.autocast(device_type=self.device.type, enabled=self.enable_amp):
                # losses = self.model.calculate_loss(interaction, interaction_aux, centroid_emb)
                losses = self.model.calculate_loss(interaction, interaction_aux)
            if isinstance(losses, tuple):
                loss = sum(losses)
                loss_tuple = tuple(per_loss.item() for per_loss in losses)
                total_loss = (
                    loss_tuple
                    if total_loss is None
                    else tuple(map(sum, zip(total_loss, loss_tuple)))
                )
            else:
                loss = losses
                total_loss = (
                    losses.item() if total_loss is None else total_loss + losses.item()
                )
            scaler.scale(loss + sync_loss).backward(retain_graph=True)
            if self.clip_grad_norm:
                clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
            scaler.step(self.optimizer)
            scaler.update()
            if self.gpu_available and show_progress:
                iter_data.set_postfix_str(
                    set_color("GPU RAM: " + get_gpu_usage(self.device), "yellow")
                )

        # Prune every 10 epochs once past epoch 10.
        if self.model.lotteryflag and epoch_idx % 10 == 0 and epoch_idx > 10:
            self.model.update_norm_adj()  # presumably; the call is not shown in the original post
        return total_loss

    def update_norm_adj(self, prune_percent=0.05):
        with torch.no_grad():
            adj_mask_tensor = self.adj_mask_train.flatten()
            print("adj_mask_tensor.shape: ", adj_mask_tensor.shape)
            nonzero = torch.abs(adj_mask_tensor) > 0
            adj_mask = adj_mask_tensor[nonzero] 
            print("adj_mask.shape: ", adj_mask.shape)
            adj_total = adj_mask.shape[0]
            adj_y, adj_i = torch.sort(adj_mask.abs())
            adj_thre_index = int(adj_total * prune_percent)
            adj_thre = adj_y[adj_thre_index]
            abs_values = torch.abs(self.adj_mask_train)
            mask = abs_values >= adj_thre
            self.adj_mask_train.data[mask] = 1
            self.adj_mask_train.data[~mask] = 0

        if self.adj_mat.is_sparse:
            dense_tensor = self.adj_mat.to_dense()
        else:
            dense_tensor = self.adj_mat
        # print("dense_tensor.shape: ", dense_tensor.shape)
        # print("adj_mask.shape: ", adj_mask.shape)
        zeros_count = (self.adj_mask_train == 0).sum().item()
        ones_count = (self.adj_mask_train == 1).sum().item()
        print(f"Number of zeros: {zeros_count}")
        print(f"Number of ones: {ones_count}")
        adj = torch.mul(dense_tensor, self.adj_mask_train)
        self.norm_adj_matrix = self.torch_normalize_adj(adj)
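The thresholding rule in update_norm_adj (sort the magnitudes of the nonzero mask entries and take the one at position int(total * prune_percent)) can be sketched more compactly with torch.kthvalue, which returns the k-th smallest value with a 1-based k. This is a minimal standalone sketch; magnitude_prune_mask is a hypothetical helper that works on a plain tensor rather than on the model attribute.

```python
import torch


def magnitude_prune_mask(mask: torch.Tensor, prune_percent: float = 0.05) -> torch.Tensor:
    """Hypothetical helper mirroring update_norm_adj's rule: return a 0/1 tensor that
    keeps entries whose magnitude is >= the threshold taken from the nonzero entries."""
    with torch.no_grad():
        nonzero_vals = mask[mask.abs() > 0].abs()
        # Index int(total * p) of an ascending sort is the (index + 1)-th smallest
        # value, which is exactly what torch.kthvalue returns for that k.
        k = int(nonzero_vals.numel() * prune_percent) + 1
        threshold = torch.kthvalue(nonzero_vals, k).values
        return (mask.abs() >= threshold).to(mask.dtype)


mask = torch.tensor([[0.9, 0.0, 0.1], [0.5, 0.3, 0.0]])
print(magnitude_prune_mask(mask, prune_percent=0.25))
```

Note that when all nonzero entries share the same magnitude (for example a freshly initialized all-ones mask), the threshold equals that magnitude and nothing is pruned; the sort-based original behaves the same way.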

Specifically, I want to prune once every fixed number of epochs, which modifies the learnable parameter:

    self.adj_mask_train = torch.nn.Parameter(self.generate_adj_mask(self.generate_daj_mat()), requires_grad=True)
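Whether the optimizer "sees" a manual change depends on whether the change is made in place on the same Parameter object the optimizer was constructed with. Writing through `.data` or under `torch.no_grad()` keeps that object; reassigning `self.adj_mask_train = nn.Parameter(...)` after the optimizer exists creates a new object that the optimizer never updates. A minimal sketch, where Toy is a hypothetical stand-in for the real model:

```python
import torch
import torch.nn as nn


class Toy(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        self.adj_mask_train = nn.Parameter(torch.ones(3, 3))


model = Toy()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
tracked = optimizer.param_groups[0]["params"][0]  # the object the optimizer updates

# In-place edit under no_grad: same object, so the optimizer keeps tracking it.
with torch.no_grad():
    model.adj_mask_train[0, 0] = 0.0
print(tracked is model.adj_mask_train)  # True

# Reassignment: a brand-new Parameter that the optimizer does not know about.
model.adj_mask_train = nn.Parameter(torch.zeros(3, 3))
print(tracked is model.adj_mask_train)  # False
```

If replacing the Parameter is unavoidable, the optimizer has to be rebuilt (or its param group updated) so that it tracks the new object.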

However, although pruning appears to succeed and I can see the modified values via self.model.adj_mask_train, when the next epoch reaches this part of the code:


I find that the parameter adj_mask_train in the model has reverted to its pre-pruning state (the initial, unpruned state). What I need is for adj_mask_train to be pruned step by step toward the state I want, but currently every modification I make to the parameter is undone in the next epoch.
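If the goal is for pruned entries to stay at zero while the rest of the mask keeps training, one option is sketched below under my own assumptions (not the only way). It (1) keeps a separate, non-trainable 0/1 record of pruned entries, here called `keep` (a hypothetical name), and multiplies it into the forward pass so those entries get exactly zero gradient, and (2) at pruning time zeroes both the parameter entries and Adam's per-entry moment estimates (`exp_avg`, `exp_avg_sq` in `torch.optim.Adam`'s state), so no leftover momentum moves them. The toy setup mirrors the repro in this thread (5x5 adj of 0.5, Adam with lr=10), and prune_smallest is a hypothetical helper that prunes exactly int(20%) of the surviving entries by magnitude instead of using a threshold, which avoids ties.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

adj = torch.full((5, 5), 0.5)               # fixed adjacency, as in the repro
adj_mask = nn.Parameter(torch.ones(5, 5))   # trainable mask
keep = torch.ones(5, 5)                     # hypothetical persistent 0/1 pruning record
optimizer = torch.optim.Adam([adj_mask], lr=10)
target = torch.randn(5, 5)


def prune_smallest(param, keep, optimizer, fraction):
    """Prune the smallest-magnitude surviving entries and clear their Adam state."""
    with torch.no_grad():
        alive = keep.flatten().nonzero().squeeze(1)
        n_prune = int(alive.numel() * fraction)
        order = param.flatten()[alive].abs().argsort()
        keep.view(-1)[alive[order[:n_prune]]] = 0
        param.mul_(keep)
        state = optimizer.state[param]
        for key in ("exp_avg", "exp_avg_sq"):  # Adam's moment estimates
            if key in state:
                state[key].mul_(keep)


for epoch in range(41):
    optimizer.zero_grad()
    # Multiplying by `keep` makes the gradient at pruned entries exactly zero.
    masked_adj = adj * adj_mask * keep
    loss = nn.functional.mse_loss(masked_adj, target)
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        prune_smallest(adj_mask, keep, optimizer, 0.2)

print(int((keep == 0).sum()), bool((adj_mask[keep == 0] == 0).all()))  # prints: 12 True
```

With zero gradient and zeroed moments, Adam's update at a pruned entry is 0 / (sqrt(0) + eps) = 0, so those entries stay exactly zero across epochs. `torch.nn.utils.prune.custom_from_mask` implements a similar reparametrization (the original tensor is kept as `<name>_orig` and multiplied by a `<name>_mask` buffer in a forward pre-hook) and may be worth comparing.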

My problem may be similar to Set nn.Parameter during training, but I couldn't find an answer to that question either. Could the last update the asker added there point to the cause of this problem?

Could you add the missing code pieces to make it a minimal and executable code snippet reproducing the issue, please?

Of course, if any part of my code is confusing or some key parts are missing, please feel free to point them out.

    def forward(self):
        all_embeddings = self.get_ego_embeddings()
        feat_mat = self.dropout_sp_mat(self.feat_mat)
        embeddings_list = [all_embeddings]
        for layer_idx in range(self.n_layers):
            # all_embeddings = torch.sparse.mm(self.norm_adj_matrix, all_embeddings)
            all_embeddings = torch.mm(self.norm_adj_matrix, all_embeddings)
            embeddings_list.append(all_embeddings)
        gcn_all_embeddings = torch.stack(embeddings_list, dim=1)
        gcn_all_embeddings = torch.mean(gcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(
            gcn_all_embeddings, [self.n_users, self.n_items]
        )
        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, interaction, core_interaction, centroid_emb=None):
        # clear the storage variable when training
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        user = interaction[self.USER_ID]
        pos_item = interaction[self.ITEM_ID]
        neg_item = interaction[self.NEG_ITEM_ID]

        user_all_embeddings, item_all_embeddings = self.forward()
        u_embeddings = user_all_embeddings[user]
        pos_embeddings = item_all_embeddings[pos_item]
        neg_embeddings = item_all_embeddings[neg_item]

        # calculate BPR loss
        pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
        neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
        mf_loss = self.mf_loss(pos_scores, neg_scores)

        # calculate regularization loss
        u_ego_embeddings = user_all_embeddings[user]
        pos_ego_embeddings = item_all_embeddings[pos_item]
        neg_ego_embeddings = item_all_embeddings[neg_item]
        reg_loss = self.reg_loss(
            u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings
        )

        loss = mf_loss + self.reg_weight * reg_loss

        return loss

    def torch_normalize_adj(self, adj):
        # Create the identity on adj's device/dtype instead of hard-coding .cuda().
        adj = adj + torch.eye(adj.shape[0], dtype=adj.dtype, device=adj.device)
        rowsum = adj.sum(1)
        d_inv_sqrt = torch.pow(rowsum, -0.5).flatten()
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
        d_mat_inv_sqrt = torch.diag(d_inv_sqrt)
        result = adj.mm(d_mat_inv_sqrt).t().mm(d_mat_inv_sqrt)
        return result

    def generate_daj_mat(self):
        inter_M = self.interaction_matrix # coo
        print(inter_M.shape) # 944 * 1483
        users, items = inter_M.nonzero() 
        row = np.concatenate([users, items + self.n_users], axis=0) 
        column = np.concatenate([items + self.n_users, users], axis=0) 
        adj_mat = sp.coo_matrix(
            (np.ones(row.shape), (row, column)),
            shape=(self.n_users + self.n_items, self.n_users + self.n_items),
        )
        return adj_mat

    def generate_adj_mask(self, input_adj):

        sparse_adj = input_adj.tocoo()
        values = torch.from_numpy(sparse_adj.data)
        row_indices = torch.from_numpy(sparse_adj.row)
        col_indices = torch.from_numpy(sparse_adj.col)
        tensor_sparse = torch.sparse_coo_tensor(
            indices=torch.stack([row_indices, col_indices], dim=0),
            values=values,
            size=sparse_adj.shape,
        )
        dense_tensor = tensor_sparse.to_dense()

        zeros = torch.zeros_like(dense_tensor)
        ones = torch.ones_like(dense_tensor)
        mask = torch.where(dense_tensor != 0, ones, zeros)
        return mask

The forward pass and the loss calculation correspond to the first and second methods above (forward and calculate_loss), respectively; the remaining methods are general helper functions.
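torch_normalize_adj computes the symmetric normalization D^(-1/2) (A + I) D^(-1/2), where D is the diagonal matrix of row sums of A + I (inverse square roots of zero row sums are set to 0). Its `adj.mm(D).t().mm(D)` form equals D (A + I)^T D, which matches that expression only when the masked matrix is symmetric. A minimal sketch of the same normalization using broadcasting instead of dense diagonal matrices; normalize_adj is a hypothetical standalone name:

```python
import torch


def normalize_adj(adj: torch.Tensor) -> torch.Tensor:
    """Symmetric normalization D^(-1/2) (A + I) D^(-1/2) for a dense square adj."""
    adj = adj + torch.eye(adj.shape[0], dtype=adj.dtype, device=adj.device)
    d_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
    # Scale row i by d_i and column j by d_j without building n x n diagonal matrices.
    return d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)


a = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
print(normalize_adj(a))  # every entry 0.5: A + I is all ones and each degree is 2
```

One caveat that applies to both versions: if the trained mask takes negative values, a row sum can become negative and `pow(-0.5)` returns NaN for it.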

Best wishes!

Does anyone know what the problem is?

Hello, I wrote a similar, self-contained script that reproduces my problem; it can now be run directly.
Due to time constraints the numerical settings may not be very rigorous, but the problem is the same.
Here is the code:

    import torch
    import torch.nn as nn
    import torch.optim as optim

    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.adj = nn.Parameter(torch.full((5, 5), 0.5), requires_grad=False)  # fixed adj
            self.adj_mask = nn.Parameter(torch.ones(5, 5), requires_grad=True)  # trainable adj_mask

        def forward(self):
            masked_adj = torch.mul(self.adj, self.adj_mask)  # mask adj with adj_mask
            return masked_adj

    def print_tensor_info(tensor):
        num_zeros = (tensor == 0).sum().item()
        num_non_zeros = tensor.numel() - num_zeros
        max_val, min_val = tensor.max(), tensor.min()
        has_nan = torch.isnan(tensor).any().item()

        print(f"Zero values: {num_zeros}")
        print(f"Non-zero values: {num_non_zeros}")
        print(f"Max value: {max_val.item()}")
        print(f"Min value: {min_val.item()}")
        print(f"Has NaN: {has_nan}")

    # prune the mask
    def prune_mask(mask, percentage):
        with torch.no_grad():
            # print("before prune")
            # print_tensor_info(mask)
            adj_mask_tensor = mask.flatten()
            # print("before_adj_mask_tensor.shape: ", adj_mask_tensor.shape)
            nonzero = torch.abs(adj_mask_tensor) > 0
            adj_mask = adj_mask_tensor[nonzero]  # 13264 - 2708
            # print("before_adj_mask.shape: ", adj_mask.shape)
            # print(adj_mask)
            adj_total = adj_mask.shape[0]
            adj_y, adj_i = torch.sort(adj_mask.abs())
            adj_thre_index = int(adj_total * percentage)
            adj_thre = adj_y[adj_thre_index]
            # print("adj_thre", adj_thre)
            abs_values = torch.abs(mask)
            index = abs_values >= adj_thre
            mask.data[index] = 1
            mask.data[~index] = 0
            # print("After prune")
            # print_tensor_info(mask)
        # print("-------")

    model = SimpleModel()

    optimizer = optim.Adam([model.adj_mask], lr=10)

    # input_data = torch.randn(5, 5)
    target = torch.randn(5, 5)

    pruneflag = False
    for epoch in range(41):
        optimizer.zero_grad()

        # if pruneflag:
        #     print("Check whether adj_mask has been pruned before forward propagation")
        #     print_tensor_info(model.adj_mask)
        masked_adj = model()
        loss = nn.MSELoss()(masked_adj, target)

        # if pruneflag:
        #     print("Check whether adj_mask has been pruned before backward propagation")
        #     print_tensor_info(model.adj_mask)
        # if pruneflag:
        #     print("Check whether the adj_mask gradient has been cleared before backpropagation")
        #     print_tensor_info(model.adj_mask.grad)
        loss.backward()
        if pruneflag:
            print("Check adj_mask gradient after backpropagation")
            # optimizer.zero_grad()
            # print_tensor_info(model.adj_mask.grad)
        optimizer.step()
        if pruneflag:
            print("Check adj_mask after backpropagation")
            print_tensor_info(model.adj_mask)
            pruneflag = False

        # prune every 20 epochs
        if epoch % 20 == 0:
            print("----epoch" + str(epoch))
            prune_mask(model.adj_mask, 0.2)
            pruneflag = True

After pruning adj_mask, I observe during backpropagation that its gradients are not what I intended: entries that were set to zero, and therefore zero out the corresponding adj entries in the Hadamard product, still receive gradients, which leads to incorrect gradient computations.
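For y = adj * mask (elementwise), the chain rule gives dL/dmask = dL/dy * adj: the gradient with respect to a mask entry depends on the corresponding adj entry and the upstream gradient, not on the mask entry's own value. Setting a mask entry to zero removes its contribution to the output, but not its gradient; it is the gradient with respect to adj that gets multiplied by the mask. A minimal check, with small illustrative tensors:

```python
import torch

adj = torch.full((2, 2), 0.5, requires_grad=True)
mask = torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)

out = adj * mask        # Hadamard product
out.sum().backward()    # upstream gradient dL/dout is 1 everywhere

print(mask.grad)  # equals adj: 0.5 everywhere, including where mask == 0
print(adj.grad)   # equals mask: zero exactly where the mask is zero
```

In the repro, adj is a constant 0.5, so every mask entry, pruned or not, receives a nonzero gradient whenever the upstream gradient is nonzero.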

Furthermore, even when I manually set the gradients to zero, adj_mask still gets updated by optimizer.step() (see the commented-out parts of my code). I suspect this is related to the momentum in the Adam optimizer, because the issue does not occur when I switch to SGD. However, that hypothesis does not explain the first issue.
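The momentum hypothesis can be checked in isolation. Adam keeps running averages of past gradients (`exp_avg`) and squared gradients (`exp_avg_sq`) per entry, so after a step with a nonzero gradient, a later step with an explicit all-zero gradient still moves the parameter; plain SGD without momentum computes `-lr * grad` and leaves it alone. A related detail: `torch.optim.Adam` skips parameters whose `.grad` is `None`, and `optimizer.zero_grad()` sets gradients to `None` by default in PyTorch 2.x (`set_to_none=True`), so the outcome depends on how the gradient was cleared. The sketch below uses an explicit zero tensor; run is a hypothetical helper.

```python
import torch


def run(optimizer_cls, **kwargs):
    p = torch.nn.Parameter(torch.tensor([1.0]))
    opt = optimizer_cls([p], **kwargs)

    # Step 1: an ordinary nonzero gradient populates the optimizer state.
    p.grad = torch.tensor([1.0])
    opt.step()

    # "Prune" the entry, then step again with an explicit all-zero gradient.
    with torch.no_grad():
        p.zero_()
    p.grad = torch.zeros_like(p)
    opt.step()
    return p.item()


print(run(torch.optim.SGD, lr=0.1))   # 0.0: the SGD update is -lr * grad = 0
print(run(torch.optim.Adam, lr=0.1))  # about -0.067: driven by exp_avg from step 1
```

SGD with `momentum > 0` would show the same effect, since its momentum buffer also carries past gradients into later steps.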

Thank you in advance for taking the time to read through my query and for any assistance you can provide.