A large number of indexing operations result in very slow backpropagation in PyTorch

When I use the following code to train my model in PyTorch, I find that the backward pass is extremely slow.

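import random

import numpy as np
import torch
import torch.nn.functional as F

global_seed = 0  # module-level counter referenced inside sampling_prob


def sampling_prob(prob, label, num_neg):
    label = label - 1  # labels are 1-based; shift to 0-based class indices
    batch, pred_len, l_m = prob.shape[0], prob.shape[1], prob.shape[2] - 1
    init_label = np.zeros_like(label)
    # torch.Size([32, 12, 42]); allocated on the same device/dtype as prob
    init_prob = torch.zeros(size=(batch, pred_len, num_neg + batch),
                            device=prob.device, dtype=prob.dtype)

    for step in range(pred_len):
        # Row k's positive is column k (its own label among the in-batch labels)
        _label = np.linspace(0, batch - 1, batch)
        _prob = torch.zeros(size=(batch, num_neg + batch),
                            device=prob.device, dtype=prob.dtype)

        batch_step_label = label[:, step]

        # Resample negatives until none of them equals a true label of this step
        random_ig = random.sample(range(1, l_m + 1), num_neg)
        while len([lab for lab in batch_step_label if lab in random_ig]) != 0:
            random_ig = random.sample(range(1, l_m + 1), num_neg)

        global global_seed
        global_seed += 1

        # Element-wise copies: first the in-batch labels, then the random negatives
        for k in range(batch):
            for i in range(num_neg + batch):
                if i < batch:
                    _prob[k, i] = prob[k, step, batch_step_label[i]]
                else:
                    _prob[k, i] = prob[k, step, random_ig[i - batch]]

        init_label[:, step] = _label
        init_prob[:, step, :] = _prob

    # cross_entropy needs integer (long) class targets
    return init_prob, torch.as_tensor(init_label, dtype=torch.long, device=prob.device)

prob = self.model(train_input, train_m1, self.mat2s, train_m2t, train_len)
prob_sample, label_sample = sampling_prob(prob, train_label, self.num_neg)
# cross_entropy expects (N, C, d) input, so move the sampled-class dimension to dim 1
loss_train = F.cross_entropy(prob_sample.permute(0, 2, 1), label_sample)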

In short, I want to sample some entries from the output of the model, i.e., prob, to calculate the loss and optimize the parameters. However, the backward pass is extremely slow. I suspect this is caused by the large number of indexing operations. What could be the reason, and how can I solve this problem?

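Each element-wise assignment into _prob is recorded by autograd as a separate in-place copy, so all “versions” of the _prob tensor are kept: the framework doesn't know that the writes never overlap and that the intermediate states could be discarded. You end up with a long chain of autograd nodes that backward() has to process one after another. With the shapes in the question (batch = 32, num_neg + batch = 42, pred_len = 12), that is 32 × 42 × 12 = 16,128 single-element copies per forward pass.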
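To illustrate the effect, here is a minimal sketch with arbitrary toy sizes. It fills a buffer element by element, as in the question, and compares the size of the resulting autograd graph with a single torch.gather that reads the same values. count_autograd_nodes is a hypothetical helper written for this example, not a PyTorch API; it simply walks grad_fn.next_functions.

```python
import torch


def count_autograd_nodes(t):
    """Hypothetical helper: count distinct autograd nodes reachable from t.grad_fn."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(nxt for nxt, _ in fn.next_functions)
    return len(seen)


x = torch.randn(8, 20, requires_grad=True)
cols = torch.randint(0, 20, (10,))
cols_list = cols.tolist()

# Pattern from the question: element-wise writes into a preallocated buffer.
# Every assignment adds at least one in-place copy node (CopySlices) to the graph.
buf = torch.zeros(8, 10)
for k in range(8):
    for i in range(10):
        buf[k, i] = x[k, cols_list[i]]

# Same values read in one call: out[k, i] = x[k, cols[i]].
out = torch.gather(x, 1, cols.expand(8, -1))

print(count_autograd_nodes(buf), count_autograd_nodes(out))
```

The loop version's count grows with the number of writes, while the gather version stays at a couple of nodes (the gather itself plus the gradient accumulator for x) regardless of size. Exact counts may differ between PyTorch versions.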

A workaround is to use torch.gather to read all values at once. Another option is to rewrite this with torch.where and torch.cat, which similarly avoids assigning into a buffer.
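One possible gather-based rewrite is sketched below, assuming (as the original code implies) that prob has shape (batch, pred_len, l_m + 1) and that label holds 1-based integer class ids of shape (batch, pred_len). The function name sampling_prob_gather is hypothetical. Only the choice of indices stays in a Python loop, which involves no autograd; all values are then read with one torch.gather call, so backward() sees a single gather node instead of thousands of copies.

```python
import random

import torch
import torch.nn.functional as F


def sampling_prob_gather(prob, label, num_neg):
    """Hypothetical gather-based variant of sampling_prob.

    prob:  (batch, pred_len, l_m + 1) scores
    label: (batch, pred_len) 1-based integer class ids (tensor or NumPy array)
    Returns sampled (batch, pred_len, batch + num_neg) and target (batch, pred_len).
    """
    label = label - 1  # shift to 0-based class indices, as in the original
    batch, pred_len, l_m = prob.shape[0], prob.shape[1], prob.shape[2] - 1

    # Build the column indices per step in plain Python; no autograd involved.
    cols = []
    for step in range(pred_len):
        step_label = [int(lab) for lab in label[:, step]]
        random_ig = random.sample(range(1, l_m + 1), num_neg)
        while any(lab in random_ig for lab in step_label):
            random_ig = random.sample(range(1, l_m + 1), num_neg)
        # First `batch` columns: in-batch labels; then the random negatives.
        cols.append(step_label + random_ig)

    # Same index row for every batch element: (batch, pred_len, batch + num_neg).
    idx = torch.tensor(cols, device=prob.device).unsqueeze(0).expand(batch, -1, -1)
    # One differentiable op: sampled[k, s, i] = prob[k, s, idx[k, s, i]].
    sampled = torch.gather(prob, 2, idx)

    # Row k's positive is column k (its own label among the in-batch labels).
    target = torch.arange(batch, device=prob.device).unsqueeze(1).repeat(1, pred_len)
    return sampled, target


if __name__ == "__main__":
    prob = torch.randn(32, 12, 501, requires_grad=True)
    label = torch.randint(1, 502, (32, 12))
    sampled, target = sampling_prob_gather(prob, label, num_neg=10)
    # cross_entropy expects the class dimension at dim 1.
    loss = F.cross_entropy(sampled.permute(0, 2, 1), target)
    loss.backward()
```

A torch.cat variant works along the same lines: for each step, torch.cat([prob[:, step, step_label], prob[:, step, random_ig]], dim=1) builds the (batch, batch + num_neg) slice via advanced indexing, and torch.stack over the steps assembles the result. That gives a few nodes per step rather than one per element.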
