Manually modified parameter during trainning, gradients are not being correctly updated. Help

Hello.

During training, after replacing a parameter, gradients are not being correctly updated.
Specifically:

  1. There is a fixed tensor adj_mat without gradients.
  2. A mask tensor adj_mask with the same dimensions as adj_mask, initially filled with 0s and 1s, which has gradients.
  3. Element-wise multiplication of adj_mat and adj_mask results in adj.
  4. Forward and backward propagation, along with parameter updates, are performed based on adj (updating adj_mask).

Initially, adj_mask remains constant, but every 20 epochs, I eliminate the smallest 5% of non-zero elements in adj_mask to generate a new adj_mask. This means that every 20 epochs, adj_mask will have fewer non-zero elements by 5%.

However, I observe that after changing adj_mask, during the next batch update, the zeros in adj_mask are also being updated, which is counterintuitive because multiplying any element by zero should not produce any gradient. Can anyone explain why this might be happening and what I should do?

I welcome any assistance from everyone, this is truly crucial for me, and I am deeply grateful.

This question is related to Why do manually modified parameters in training process revert to their original values after optimizer.step() is executed?

Simplified code:

    def update_norm_adj(self,prune_percent = 0.05):
    # Prune adj_mask, executed every 20 epochs
        with torch.no_grad():
            adj_mask_tensor = self.adj_mask_train.flatten()
            nonzero = torch.abs(adj_mask_tensor) > 0
            adj_mask = adj_mask_tensor[nonzero] 
            adj_total = adj_mask.shape[0]
            adj_y, adj_i = torch.sort(adj_mask.abs())
            adj_thre_index = int(adj_total * prune_percent)
            adj_thre = adj_y[adj_thre_index]
            abs_values = torch.abs(self.adj_mask_train)
            mask = abs_values >= adj_thre
            self.adj_mask_train.data[mask] = 1
            self.adj_mask_train.data[~mask] = 0
        self.pruneflag = True
        print("prune success!")
    def forward(self):
        if self.lotteryflag and self.pruneflag
            #In each forward phase, the adjacency matrix is recalculated
            dense_tensor = self.adj_mat
            adj = torch.mul(dense_tensor, self.adj_mask_train)
            self.norm_adj_matrix = self.torch_normalize_adj(adj)

        all_embeddings = self.get_ego_embeddings()  
        feat_mat = self.dropout_sp_mat(self.feat_mat)
        embeddings_list = [all_embeddings]  

        for layer_idx in range(self.n_layers):
            all_embeddings = torch.mm(self.norm_adj_matrix, all_embeddings)
            embeddings_list.append(all_embeddings)
        lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
        lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(
            lightgcn_all_embeddings, [self.n_users, self.n_items]
        )
        return user_all_embeddings, item_all_embeddings
    def torch_normalize_adj(self,adj):
        adj = adj + torch.eye(adj.shape[0]).cuda()
        rowsum = adj.sum(1)
        d_inv_sqrt = torch.pow(rowsum, -0.5).flatten()
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
        d_mat_inv_sqrt = torch.diag(d_inv_sqrt).cuda()
        result = adj.mm(d_mat_inv_sqrt).t().mm(d_mat_inv_sqrt)
        return result

    def _train_epoch(self, train_data,aux_train_data, epoch_idx, show_progress=False):
        # Train the model in an epoch

        self.model.train()
        total_loss = None
        iter_data = ( 
            tqdm(
                train_data,
                total=len(train_data),
                ncols=100,
                desc=set_color(f"Train {epoch_idx:>5}", "pink"),
            )
            if show_progress
            else train_data
        )
        if not self.config["single_spec"] and train_data.shuffle:
            train_data.sampler.set_epoch(epoch_idx)

        scaler = amp.GradScaler(enabled=self.enable_scaler)
        batchcount = 0

        dataloader_iterator = iter(aux_train_data)

        for batch_index, interaction in enumerate(iter_data):
            try:
                interaction_aux = next(dataloader_iterator)
            except StopIteration: 
                dataloader_iterator = iter(aux_train_data) #新建一个
                interaction_aux= next(dataloader_iterator)

            interaction = interaction.to(self.device)
            interaction_aux = interaction_aux.to(self.device)
            self.optimizer.zero_grad()
            sync_loss = 0
            if not self.config["single_spec"]:
                self.set_reduce_hook()
                sync_loss = self.sync_grad_loss()
            with torch.autocast(device_type=self.device.type, enabled=self.enable_amp):
                losses = self.model.calculate_loss(interaction, interaction_aux)
            if isinstance(losses, tuple):
                loss = sum(losses)
                loss_tuple = tuple(per_loss.item() for per_loss in losses)
                total_loss = (
                    loss_tuple
                    if total_loss is None
                    else tuple(map(sum, zip(total_loss, loss_tuple)))
                )
            else:
                loss = losses
                total_loss = (
                    losses.item() if total_loss is None else total_loss + losses.item()
                )
            self._check_nan(loss)
            # scaler.scale(loss + sync_loss).backward(retain_graph=True)
            scaler.scale(loss + sync_loss).backward()
            if self.clip_grad_norm:
                clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
            scaler.step(self.optimizer)
            scaler.update()
            if self.gpu_available and show_progress:
                iter_data.set_postfix_str(
                    set_color("GPU RAM: " + get_gpu_usage(self.device), "yellow")
                )
            
   

        if self.model.lotteryflag and epoch_idx % 10 == 0 and epoch_idx > 0:
            print("prune!")
            self.model.update_norm_adj()
            self.model.zero_grad()
        return total_loss

(more code can be found here: Why do manually modified parameters in training process revert to their original values after optimizer.step() is executed?)

operation result:

# Before pruning
before_adj_mask.shape:  torch.Size([161616])
tensor([0.6044, 0.6164, 0.6389,  ..., 0.7365, 1.3904, 0.8289], device='cuda:0')
Number of zeros: 6747593
Number of ones: 153536
prune success!

# After pruning
after_adj_mask_tensor.shape:  torch.Size([6901129])
bfater_adj_mask.shape:  torch.Size([153536])
tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')

# Before the next backward
  Zero values: 6747593
  Non-zero values: 153536
  Max value: 1.0
  Min value: 0.0
  Has NaN: False
# After the next backward
  Zero values: 6739513
  Non-zero values: 161616
  Max value: 1.012004017829895
  Min value: -0.012804072350263596

The crux of the problem is that adj_mask_train updates the zero value, and what I actually do is:adj = torch.mul(dense_tensor, self.adj_mask_train),it should not have gradients.

Although I have tried many methods to solve it, I still can’t find where the problem lies. I really need the professional ability of everyone on the forum . Thank you very much.