Does using torch.where cause the model's parameter gradients to become zero?

Here is the forward() method of my PyTorch model:

    def forward(self, x, output_type, *unused_args, **unused_kwargs):
        gru_output, gru_hn = self.gru(x)
        # Decoder (Graph Adjacency Reconstruction)
        for data_batch_idx in range(x.shape[0]):
            pred = self.decoder(gru_output[data_batch_idx, -1, :])  # gru_output[-1] => only take last time-step
            pred_graph_adj = pred.reshape(1, -1) if data_batch_idx == 0 else torch.cat((pred_graph_adj, pred.reshape(1, -1)), dim=0)
        if output_type == "discretize":
            bins = torch.tensor(self.model_cfg['output_bins']).reshape(-1, 1)
            num_bins = len(bins)-1
            bins = torch.concat((bins[:-1], bins[1:]), dim=1)
            discretize_values = np.linspace(-1, 1, num_bins)
            for lower, upper, discretize_value in zip(bins[:, 0], bins[:, 1], discretize_values):
                pred_graph_adj = torch.where((pred_graph_adj <= upper) & (pred_graph_adj > lower), discretize_value, pred_graph_adj)
            pred_graph_adj = torch.where(pred_graph_adj < bins.min(), bins.min(), pred_graph_adj)

        return pred_graph_adj
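
To isolate the discretize branch, here is a minimal sketch that runs the same torch.where chain on a small hand-made tensor instead of the decoder output. The bin edges [-1, -0.5, 0, 0.5, 1] are hypothetical (the real ones come from model_cfg['output_bins']), and torch.linspace stands in for np.linspace. The "true" branch of each torch.where is a constant, so no gradient can flow back to the elements it replaces:

```python
import torch

# Hypothetical bin edges; in the model they come from model_cfg['output_bins'].
edges = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0]).reshape(-1, 1)
num_bins = len(edges) - 1
bins = torch.cat((edges[:-1], edges[1:]), dim=1)  # rows are [lower, upper]
discretize_values = torch.linspace(-1, 1, num_bins)

# Stand-in for the decoder output; requires_grad so its gradient can be inspected.
pred = torch.tensor([-0.9, -0.2, 0.3, 0.7, 1.5], requires_grad=True)

out = pred
for lower, upper, value in zip(bins[:, 0], bins[:, 1], discretize_values):
    # Elements inside (lower, upper] are replaced by a constant with no grad history.
    out = torch.where((out <= upper) & (out > lower), value, out)
out = torch.where(out < bins.min(), bins.min(), out)

out.sum().backward()
print(pred.grad)  # tensor([0., 0., 0., 0., 1.])
```

In this sketch only the element above the largest edge (1.5) passes through every torch.where unchanged and keeps a gradient; every element that lands in a bin gets a gradient of exactly zero.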

And here is the training snippet:

    pred = self.forward(x, output_type=self.model_cfg['output_type'])
    batch_loss = self.loss_fn(pred, y)
    self.optimizer.zero_grad()
    batch_loss.backward()
    self.optimizer.step()
    self.scheduler.step()

  1. When output_type is not "discretize" (torch.where is not used), sum([p.grad.sum() for p in self.decoder.parameters()]) is non-zero.
    • But when output_type is "discretize" (torch.where is used), sum([p.grad.sum() for p in self.decoder.parameters()]) is zero.
  2. I've checked batch_loss; it is not zero.
  3. I've checked requires_grad for all of the model's weights; they are all True.
  4. I've checked the computational graph; pred and batch_loss are connected to the model's weights.
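
A self-contained sketch for checking observation 1 outside the full training loop, comparing the decoder's gradients with and without the discretize step. Everything in it is a hypothetical stand-in: nn.Linear(8, 4) for self.decoder, random tensors for the GRU features and targets, made-up bin edges, and MSE for self.loss_fn. It sums absolute gradient values, because a signed sum such as p.grad.sum() can come out near zero even when individual gradients are not, and it uses retain_grad() to see which decoder outputs still receive any gradient:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for self.decoder, the GRU features and the targets.
decoder = nn.Linear(8, 4)
features = torch.randn(16, 8)
target = torch.randn(16, 4)

def discretize(pred, edges=(-1.0, -0.5, 0.0, 0.5, 1.0)):
    # Same piecewise-constant mapping as forward(); the edges are made up.
    edges = torch.tensor(edges)
    values = torch.linspace(-1, 1, len(edges) - 1)
    for lower, upper, value in zip(edges[:-1], edges[1:], values):
        pred = torch.where((pred <= upper) & (pred > lower), value, pred)
    return torch.where(pred < edges.min(), edges.min(), pred)

def run(use_discretize):
    decoder.zero_grad()
    raw = decoder(features)
    raw.retain_grad()  # keep the gradient that reaches the decoder output
    pred = discretize(raw) if use_discretize else raw
    F.mse_loss(pred, target).backward()
    # Sum of absolute values: unlike a signed sum, it cannot cancel out.
    total = sum(p.grad.abs().sum().item() for p in decoder.parameters())
    return raw, total

raw_plain, total_plain = run(use_discretize=False)
raw_disc, total_disc = run(use_discretize=True)
print(f"decoder |grad| without discretize: {total_plain:.4f}")
print(f"decoder |grad| with discretize:    {total_disc:.4f}")
print("outputs still receiving gradient:",
      (raw_disc.grad != 0).sum().item(), "of", raw_disc.numel())
```

With these stand-ins, any decoder output in (-1, 1] receives zero gradient once discretized, so the decoder's gradient comes only from outputs that fall outside the bins.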

My questions are:

  1. Does using torch.where cause the model’s parameter gradients to become zero?
  2. If torch.where doesn't cause this, what other reasons could there be?

PS. I asked the same question on Stack Overflow: "Does using torch.where cause the model's parameter gradients to become zero?"