Here is the forward() method of my PyTorch model:
def forward(self, x, output_type, *unused_args, **unused_kwargs):
    gru_output, gru_hn = self.gru(x)
    # Decoder (Graph Adjacency Reconstruction)
    for data_batch_idx in range(x.shape[0]):
        pred = self.decoder(gru_output[data_batch_idx, -1, :])  # gru_output[-1] => only take last time-step
        pred_graph_adj = pred.reshape(1, -1) if data_batch_idx == 0 else torch.cat((pred_graph_adj, pred.reshape(1, -1)), dim=0)
    if output_type == "discretize":
        bins = torch.tensor(self.model_cfg['output_bins']).reshape(-1, 1)
        num_bins = len(bins) - 1
        bins = torch.concat((bins[:-1], bins[1:]), dim=1)
        discretize_values = np.linspace(-1, 1, num_bins)
        for lower, upper, discretize_value in zip(bins[:, 0], bins[:, 1], discretize_values):
            pred_graph_adj = torch.where((pred_graph_adj <= upper) & (pred_graph_adj > lower), discretize_value, pred_graph_adj)
        pred_graph_adj = torch.where(pred_graph_adj < bins.min(), bins.min(), pred_graph_adj)
    return pred_graph_adj
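For clarity, here is a minimal standalone sketch of what the "discretize" branch does, using hypothetical bin edges and a toy prediction tensor in place of self.model_cfg['output_bins'] and the real decoder output:

import numpy as np
import torch

output_bins = [-1.0, -0.5, 0.0, 0.5, 1.0]   # hypothetical edges standing in for self.model_cfg['output_bins']
pred_graph_adj = torch.tensor([[-0.8, -0.2, 0.3, 0.9]], requires_grad=True)  # toy decoder output

bins = torch.tensor(output_bins).reshape(-1, 1)
num_bins = len(bins) - 1
bins = torch.concat((bins[:-1], bins[1:]), dim=1)   # rows of (lower, upper) edges
discretize_values = np.linspace(-1, 1, num_bins)    # one constant value per bin

out = pred_graph_adj
for lower, upper, discretize_value in zip(bins[:, 0], bins[:, 1], discretize_values):
    # entries inside (lower, upper] are replaced by that bin's constant value
    out = torch.where((out <= upper) & (out > lower), discretize_value, out)
# entries below the lowest edge are clamped to it
out = torch.where(out < bins.min(), bins.min(), out)

print(out)  # [[-1.0000, -0.3333, 0.3333, 1.0000]] -- every entry is now a constant bin value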
And here is a snippet of the training step:
pred = self.forward(x, output_type=self.model_cfg['output_type'])
batch_loss = self.loss_fn(pred, y)
self.optimizer.zero_grad()
batch_loss.backward()
self.optimizer.step()
self.scheduler.step()
- When output_type is not "discretize" (not using torch.where), sum([p.grad.sum() for p in self.decoder.parameters()]) is non-zero (see the sketch of this check after this list).
- But when output_type is "discretize" (using torch.where), sum([p.grad.sum() for p in self.decoder.parameters()]) is zero.
- I've checked batch_loss; it's not zero.
- I've checked requires_grad on all of the model's weights; they are all True.
- I've checked the computational graph; pred and batch_loss are connected to the model's weights.
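For reference, here is a condensed sketch of how I run that gradient check, continuing the training snippet above (it assumes the same self.loss_fn, self.optimizer, and self.decoder attributes):

pred = self.forward(x, output_type=self.model_cfg['output_type'])
batch_loss = self.loss_fn(pred, y)
self.optimizer.zero_grad()
batch_loss.backward()
# inspect the decoder gradients right after backward(), before optimizer.step()
decoder_grad_sum = sum(p.grad.sum() for p in self.decoder.parameters())
print(self.model_cfg['output_type'], batch_loss.item(), decoder_grad_sum.item())
# batch_loss is non-zero in both cases; decoder_grad_sum is zero only when output_type == "discretize"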
My questions are:
- Does using torch.where cause the model's parameter gradients to become zero?
- If torch.where doesn't cause that, what are other possible reasons?
P.S. I asked the same question on Stack Overflow: python - Does using torch.where cause the model's parameter gradients to become zero? - Stack Overflow