when I use the following code to train my model in pytorch, I find the backward operation is extremely slow.
def sampling_prob(prob, label, num_neg): label = label - 1 batch, pred_len, l_m, = prob.shape, prob.shape, prob.shape - 1 init_label = np.zeros_like(label) init_prob = torch.zeros(size=(batch, pred_len, num_neg + batch)) # torch.Size([32, 12, 42]) for step in range(pred_len): _label = np.linspace(0, batch - 1, batch) _prob = torch.zeros(size=(batch, num_neg + batch)) batch_step_label = label[:, step] random_ig = random.sample(range(1, l_m + 1), num_neg) while len([lab for lab in batch_step_label if lab in random_ig]) != 0: random_ig = random.sample(range(1, l_m + 1), num_neg) global global_seed random.seed(global_seed) global_seed += 1 for k in range(batch): for i in range(num_neg + batch): if i < batch: _prob[k, i] = prob[k, step, batch_step_label[i]] else: _prob[k, i] = prob[k, step, random_ig[i - batch]] init_label[:, step] = _label init_prob[:, step, :] = _prob prob = self.model(train_input, train_m1, self.mat2s, train_m2t, train_len) prob_sample, label_sample = sampling_prob(prob, train_label, self.num_neg) loss_train = F.cross_entropy(prob_sample, label_sample) loss_train.backward() optimizer.step() optimizer.zero_grad()
Indeed, I want to sample some data from the the output of the model, i.e.,prob, to calculate the loss and optimize the parameters. But the speed of the backward operation is extremely slow. I am wondering this is because a lot of indexing operations. What would be the reason for that? How can I solve this problem?