When I train my model in PyTorch with the following code, the backward pass is extremely slow.

```
import random

import numpy as np
import torch
import torch.nn.functional as F


def sampling_prob(prob, label, num_neg):
    label = label - 1
    batch, pred_len, l_m = prob.shape[0], prob.shape[1], prob.shape[2] - 1
    init_label = np.zeros_like(label)
    init_prob = torch.zeros(size=(batch, pred_len, num_neg + batch))  # torch.Size([32, 12, 42])
    for step in range(pred_len):
        _label = np.linspace(0, batch - 1, batch)  # the target for row k is column k
        _prob = torch.zeros(size=(batch, num_neg + batch))
        batch_step_label = label[:, step]
        # Resample the negatives until none of them equals a label at this step.
        random_ig = random.sample(range(1, l_m + 1), num_neg)
        while len([lab for lab in batch_step_label if lab in random_ig]) != 0:
            random_ig = random.sample(range(1, l_m + 1), num_neg)
        global global_seed
        random.seed(global_seed)
        global_seed += 1
        # Columns [0, batch) hold the scores at the batch labels, the rest hold the negatives.
        for k in range(batch):
            for i in range(num_neg + batch):
                if i < batch:
                    _prob[k, i] = prob[k, step, batch_step_label[i]]
                else:
                    _prob[k, i] = prob[k, step, random_ig[i - batch]]
        init_label[:, step] = _label
        init_prob[:, step, :] = _prob


# Training step (excerpt from the training loop)
prob = self.model(train_input, train_m1, self.mat2s, train_m2t, train_len)
prob_sample, label_sample = sampling_prob(prob, train_label, self.num_neg)
loss_train = F.cross_entropy(prob_sample, label_sample)
loss_train.backward()
optimizer.step()
optimizer.zero_grad()
```

What I want is to sample some entries from the model output, i.e., `prob`, compute the loss on them, and optimize the parameters. However, the backward pass is extremely slow. I suspect this is caused by the large number of indexing operations. Is that the reason, and how can I solve this problem?
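A likely contributor is that every scalar assignment of the form `_prob[k, i] = prob[...]` is recorded by autograd as its own in-place operation, so with the shapes in the comment above the graph gains on the order of 32 × 42 × 12 small nodes that backward has to walk one by one. The sketch below shows one way the same entries could be selected with a single advanced-indexing call. It is only a sketch under assumptions: `sampling_prob_vectorized` and its `rng` parameter are hypothetical names, `label` is assumed to be an integer tensor of shape `(batch, pred_len)` with 1-based values as in the code above, and the `global_seed` reseeding is left out.

```python
import random

import torch


def sampling_prob_vectorized(prob, label, num_neg, rng=random):
    """Hypothetical vectorized variant of sampling_prob (not from the original code).

    prob:  float tensor of shape (batch, pred_len, num_classes)
    label: integer tensor of shape (batch, pred_len), 1-based (assumption)
    rng:   object with a sample() method, e.g. the random module or random.Random
    """
    label = label.long() - 1
    batch, pred_len, num_classes = prob.shape
    l_m = num_classes - 1

    cols = []
    for step in range(pred_len):
        step_label = label[:, step]  # (batch,)
        forbidden = set(step_label.tolist())
        # Same rejection sampling as the loop version: redraw until no
        # negative index coincides with a label of this step.
        neg = rng.sample(range(1, l_m + 1), num_neg)
        while forbidden.intersection(neg):
            neg = rng.sample(range(1, l_m + 1), num_neg)
        neg = torch.tensor(neg, dtype=torch.long, device=prob.device)
        # Label columns first, negative columns after them.
        cols.append(torch.cat([step_label, neg]))
    cols = torch.stack(cols)  # (pred_len, batch + num_neg)

    # One indexing op: sampled[k, s, i] = prob[k, s, cols[s, i]]
    steps = torch.arange(pred_len, device=prob.device).unsqueeze(1)  # (pred_len, 1)
    sampled = prob[:, steps, cols]  # (batch, pred_len, batch + num_neg)

    # The target for row k is column k at every step.
    target = torch.arange(batch, device=prob.device).unsqueeze(1).repeat(1, pred_len)
    return sampled, target
```

Only the Python loop over `pred_len` remains, and it touches index tensors rather than `prob`, so autograd records a single indexing node for the whole selection. If the loss is meant to be taken over the sampled columns, something like `F.cross_entropy(sampled.permute(0, 2, 1), target)` would put the class dimension where `cross_entropy` expects it; whether that matches the intended loss is an assumption.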