Softmaxing over concatenated subsets of minibatch output logits yields a 'variables modified by inplace operation' error

Hello,
in my current work on open-domain QA, I am trying to implement a model that processes a mini-batch of documents and applies a softmax over the logits from multiple documents (subsets of the mini-batch) computed by the same model. For example:

Minibatch form:
=== logits from doc 0 ===
=== logits from doc 1 === 
=== logits from doc 2 ===
=== logits from doc 3 === 
=== logits from doc 4 ===
=== logits from doc 5 === 
=== logits from doc 6 === 
=== logits from doc 7 ===

(0,3) and (3,8) are the ranges to be normalized over:
loss_term1 = CE(softmax(concat(logits from doc 0, logits from doc 1, logits from doc 2)))
loss_term2 = CE(softmax(concat(logits from doc 3, ..., logits from doc 7)))
loss = loss_term1 + loss_term2
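
In code, the intended computation is roughly the following sketch (names, shapes and values here are only illustrative, not taken from my real model):

import torch
import torch.nn.functional as F

# toy start logits: 8 documents x 10 tokens each (illustrative shapes)
logits = torch.randn(8, 10, requires_grad=True)
ranges = [(0, 3), (3, 8)]   # document subsets to normalize over
gt = [5, 12]                # ground-truth token index inside each concatenated range

loss = 0
for (lo, hi), target in zip(ranges, gt):
    flat = logits[lo:hi].reshape(-1)                  # concatenate the documents in the range
    loss = loss + F.cross_entropy(flat.unsqueeze(0),  # softmax over the whole range
                                  torch.tensor([target]))
loss.backward()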

Unfortunately, I am running into a 'one of the variables needed for gradient computation has been modified by an inplace operation' error.

An example iteration of the code below proceeds as follows:

  1. The model processes a batch of e.g. 20 documents, obtaining logits for each token of each document: logprobs_S, logprobs_E = reader_model(batch). logprobs_S has size batch x longest_document_in_minibatch.
  2. The model iterates over each range of documents to be processed from the set of all documents, e.g. paragraph_ranges=[(0,5),(5,20),(20,22)]. A subset of the paragraph_ranges list is always processed via one mini-batch. Assume that, since our mini-batch size is 20, the model has processed the documents currently_processed_range = (0,20).
  3. Since currently_processed_range = (0,20), the document subsets given by the ranges paragraph_ranges[0] and paragraph_ranges[1] have been processed in this mini-batch. Now I would like to compute one loss term for each document subset.
  4. The logits from logprobs_S, logprobs_E for each processed range are indexed as log_probs_S_per_ex = logprobs_S[to_be_processed_range[0] - os:to_be_processed_range[1] - os]. For instance, for to_be_processed_range=(0,5), we pick logprobs_S[0:5].
  5. As we would like to maximize the probability of the ground-truth span, we compute cross entropy with the ground-truth answer span indices, which denote the offset from the start of the file. Since we want to normalize over more than one file, we need to adjust these indices by an offset given by the number of logits in the preceding documents (done in the part where gt_start += offset).
  6. Finally, after picking the logits for the documents from the same to_be_processed_range and adjusting the ground truths by the offset, I unroll the logits from these documents into one vector and compute the loss over it.
  7. After processing all to_be_processed_range ranges in the mini-batch ((0,5) and (5,20) in this example), we compute the loss and occasionally apply the accumulated gradients to the parameters (we accumulate gradients).
logprobs_S, logprobs_E = reader_model(batch)

print(f"Currently processed range {currently_processed_range}")
loss = 0
for example_ix, to_be_processed_range in enumerate(paragraph_ranges):

    # check whether paragraphs of to_be_processed_range have been processed
    # if they have been processed, to_be_processed_range is inside currently_processed_range
    if to_be_processed_range[0] >= currently_processed_range[0] and \
            to_be_processed_range[1] <= currently_processed_range[1]:

        # offset that must be subtracted from to_be_processed_range
        # to align with the currently processed range index
        # e.g. if we previously processed range (0,20),
        # and now we are processing (20,40),
        # the 20th output is actually the 0th output in logprobs_S/logprobs_E

        os = currently_processed_range[0]

        print(f"Processing range {to_be_processed_range[0]} - {to_be_processed_range[1]}.")
        print(f"Picking logit indices {to_be_processed_range[0] - os} - {to_be_processed_range[1] - os}.")
        # pick logits and answers corresponding to paragraphs aligned with 1 question
        log_probs_S_per_ex = logprobs_S[to_be_processed_range[0] - os:to_be_processed_range[1] - os]
        log_probs_E_per_ex = logprobs_E[to_be_processed_range[0] - os:to_be_processed_range[1] - os]
        answers_per_e = answers[to_be_processed_range[0] - os:to_be_processed_range[1] - os]

        # pick gt answer token indices for this example
        gt_start, gt_end = answers_per_e[flat_gt_index[example_ix]][0]

        # compute the offset: how many tokens (there is 1 logit per token) come from the documents before the GT document
        offset = log_probs_S_per_ex[:flat_gt_index[example_ix]].numel()

        # add offset to closed-domain ground truths
        gt_start += offset
        gt_end += offset

        # flatten paragraphs
        log_probs_S_per_ex = log_probs_S_per_ex.view(-1)
        log_probs_E_per_ex = log_probs_E_per_ex.view(-1)

        # compute the loss over the flattened paragraphs of this example
        loss_s = reader_loss(log_probs_S_per_ex.unsqueeze(0), gt_start)
        loss_e = reader_loss(log_probs_E_per_ex.unsqueeze(0), gt_end)
        loss += loss_s + loss_e

        loss_elements += 1
        total_updates_so_far += 1

if loss != 0:
    total_loss += loss.item()
    loss.backward()
if loss_elements >= self.config["reader_batch_size"]:
    pbar_iterator.set_description(f"Training loss: {total_loss / total_updates_so_far}")
    loss_elements = 0
    torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, reader_model.parameters()), 5.)
    reader_optimizer.step()
    reader_optimizer.zero_grad()

Unfortunately, I am getting the following error immediately after calling loss.backward() for the first time. This happens only when I process more than one document subset within one mini-batch, e.g. (0,5) and (5,20). So far I have failed to figure out exactly what to do to prevent it; I have made it work by iterating over only one set of documents per mini-batch, which is a very slow solution. Is there any obvious problem I am missing?

/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
  File "/home/researchpc/projects/QA/scripts/openqa/run_openret.py", line 77, in <module>
    framework.fit(query_enc_model, reader_model, reader_config=val_reader_config)
  File "/home/researchpc/projects/QA/scripts/openqa/openretrieval_faster.py", line 189, in fit
    train_iter, fields, reader_loss=CrossEntropyLoss())
  File "/home/researchpc/projects/QA/scripts/openqa/openretrieval_faster.py", line 420, in train_epoch_normalized_GTonly
    loss_e = reader_loss(log_probs_E_per_ex.unsqueeze(0), gt_end)
  File "/opt/anaconda3/envs/mlenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/mlenv/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 916, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/opt/anaconda3/envs/mlenv/lib/python3.6/site-packages/torch/nn/functional.py", line 1995, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/opt/anaconda3/envs/mlenv/lib/python3.6/site-packages/torch/nn/functional.py", line 1824, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

2019-09-30 14:31:23,235 ERROR root: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1]] is at version 8; expected version 7 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Traceback (most recent call last):
  File "/home/researchpc/projects/QA/scripts/openqa/run_openret.py", line 80, in <module>
    raise be
  File "/home/researchpc/projects/QA/scripts/openqa/run_openret.py", line 77, in <module>
    framework.fit(query_enc_model, reader_model, reader_config=val_reader_config)
  File "/home/researchpc/projects/QA/scripts/openqa/openretrieval_faster.py", line 189, in fit
    train_iter, fields, reader_loss=CrossEntropyLoss())
  File "/home/researchpc/projects/QA/scripts/openqa/openretrieval_faster.py", line 427, in train_epoch_normalized_GTonly
    loss.backward()
  File "/opt/anaconda3/envs/mlenv/lib/python3.6/site-packages/torch/tensor.py", line 118, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/anaconda3/envs/mlenv/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1]] is at version 8; expected version 7 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Hi,

One way to pinpoint such errors is to enable anomaly detection to know which operation in the forward caused the error in the backward.

You don't seem to be using a lot of inplace operations. Maybe replace the loss += by loss = loss +?
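
For example (the forward call below is just a placeholder for your own training step):

import torch

# print a traceback of the forward op that produced the failing gradient
torch.autograd.set_detect_anomaly(True)

# or restrict the (slower) checks to a single pass
with torch.autograd.detect_anomaly():
    logprobs_S, logprobs_E = reader_model(batch)   # placeholder forward
    loss = logprobs_S.sum() + logprobs_E.sum()     # placeholder loss
    loss.backward()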

Thanks for the reply @albanD,
I have already enabled anomaly detection, otherwise there would be no stack trace no. 1.
After writing this long post following hours of debugging, I have replaced torch's CrossEntropyLoss with my own implementation and now it seems to work :smile:.

loss_s = -F.log_softmax(log_probs_S_per_ex, 0)[gt_start[0]]
loss_e = -F.log_softmax(log_probs_E_per_ex, 0)[gt_end[0]]
loss += loss_s + loss_e

Following your advice, I have also tried changing the code to loss = loss + loss_s + loss_e, but the results are the same when using torch's CrossEntropyLoss. Is there an inplace operation somewhere inside CrossEntropyLoss?

Ho,

Sorry about that. I missed the first stacktrace!

That is interesting. Can you manage to reproduce that with a small code sample (20/30 lines)? I would be interested in finding out why this happens!

I have been trying until now, but everything simple I have created seems to work... I will make my code public in time and add the reference here (if I don't forget).

That must be some very specific interaction in your code then :confused:
Have you tried .clone() on each input and output of the CrossEntropyLoss? If one of them fixes the issue, then that is the one that was modified in place :smiley:, e.g. loss_s = reader_loss(log_probs_S_per_ex.unsqueeze(0).clone(), gt_start.clone()).clone()


Interestingly, cloning gt really works!

loss_s = reader_loss(log_probs_S_per_ex.unsqueeze(0), gt_start.clone())
loss_e = reader_loss(log_probs_E_per_ex.unsqueeze(0), gt_end.clone())

and after closer investigation, I have found the culprit! It is the += operator used to add the offset to the ground truths.
Changing

offset = log_probs_S_per_ex[:flat_gt_index[example_ix]].numel()
# add offset to closed-domain ground truths
gt_start += offset
gt_end += offset

to

offset = log_probs_S_per_ex[:flat_gt_index[example_ix]].numel()
# add offset to closed-domain ground truths
gt_start = gt_start + offset
gt_end = gt_end + offset

solves this issue (without any explicit cloning)!
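
For anyone hitting the same thing, here is a minimal sketch of what I think was going on (my understanding, assuming the ground truths are views into one shared answers tensor, which seems to be the case since gt_start/gt_end come from indexing answers): the in-place += in a later iteration bumps the version counter of the target tensor that an earlier CrossEntropyLoss already saved for its backward.

import torch
import torch.nn.functional as F

# all ground-truth indices live in one tensor, so the slices below are views
answers = torch.tensor([2, 3])
logits = torch.randn(2, 5, requires_grad=True)

loss = 0
for i in range(2):
    gt = answers[i:i + 1]   # view: shares the version counter of answers
    gt += 1                 # in-place: bumps that counter for every view
    loss = loss + F.cross_entropy(logits[i:i + 1], gt)

loss.backward()             # RuntimeError: ... modified by an inplace operation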

Nice catch!

Cloning has the side effect of making inplace operations out of place, so it can help track down this kind of issue!
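
A tiny illustration of the difference (just a sketch):

import torch

x = torch.tensor([1, 2, 3])
view = x[0:1]           # shares x's storage and version counter
copy = x[0:1].clone()   # independent storage and version counter

view += 10              # bumps the version seen by anything that saved a view of x
copy += 10              # leaves x and its version counter untouched
print(x)                # tensor([11, 2, 3])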
