Recursive call to backward()

SETUP
pytorch: 1.5.1
huggingface transformers: 3.0.2
python: 3.7.6
OS: Pop!_OS 20.04 on VM

I’m experiencing some strange behavior with backward(). Below is some pseudocode that shows what I am doing:

import sys
import torch

ZERO = sys.float_info.min
ZERO_PT = torch.tensor(ZERO)

def sentence_loss(sent: str):
    # sentence_vec: remove punctuation, lower case, split on spaces, prepend "<s>"
    # and append "</s>" start/stop tokens; returns the token ids over the vocab.
    vec = torch.tensor(sentence_vec(sent), dtype=torch.float, requires_grad=True)
    last_idx = min(max_ngram, len(vec))  # max_ngram is an int

    # pkatz is the Katz backoff probability; it returns a tensor with grad_fn set.
    probs = [max(ZERO_PT, pkatz(vec[0:i])) for i in range(2, last_idx + 1)]
    for i in range(1, len(vec) - last_idx + 1):
        j = i + last_idx
        probs.append(max(ZERO_PT, pkatz(vec[i:j])))

    probs = torch.stack(probs)
    log_probs = torch.log(probs)
    log_prob = torch.sum(log_probs)
    len_tensor = torch.tensor(len(vec), dtype=float, requires_grad=True)
    final_prob = torch.true_divide(-log_prob, len_tensor)
    return final_prob

class MyModel(MyTextGenModel):
    def __init__(self):
        ...

    def forward(self, input_str: str):
        gen_ids = self.generate(input_str, max_length=25)
        decoded_gen_sample = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        return decoded_gen_sample


model = MyModel()

for epoch in range(2):
    for i, input_str in enumerate(MyDataLoader, 0):
        output = model(input_str)
        print(output)
        loss = sentence_loss(output)
        loss.backward()

In words, MyModel inherits from a text generation model (namely GPT2 from HuggingFace Transformers). The forward method takes an input string from the training set, generates tokens from GPT2 until max_length is hit, and returns the generated portion. That output is given to sentence_loss() to calculate the Katz backoff loss with respect to an n-gram model I trained previously.

PROBLEM
If I put a breakpoint at print(output) and one at loss.backward(), the training loop will get through 2 examples just fine. On the third example, the breakpoint at loss.backward() stops the code. Then when I resume the program from there, instead of hitting the breakpoint at print(output), it comes back to loss.backward(). Resuming the program from that point results in the following error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Does anyone know why this might happen? The error makes sense because the computation graph has been cleared by the time it “recursively” comes back to loss.backward(), but I can’t figure out why it would do that recursion.
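
For reference, the error message itself is easy to reproduce in isolation; a tiny standalone script (unrelated to my model) triggers the same message just by calling backward() twice on one graph:

import torch

x = torch.ones(3, requires_grad=True)
y = torch.exp(x).sum()  # exp saves its output for the backward pass

y.backward()  # first call works, then the saved buffers are freed
y.backward()  # RuntimeError: Trying to backward through the graph a second time...
# retain_graph=True on the first call suppresses the error, but in my case that
# would only hide whatever is triggering the second backward() call.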

THINGS I’VE CHECKED

  • I’ve checked the loss tensor and there isn’t anything that stands out to me as problematic (grad_fn is set, requires_grad=True, etc.).

  • Technically there are no learned parameters given to sentence_loss(). sentence_vec() returns a list of ints over the vocabulary the n-gram model was trained on to represent sent. I tried removing the requires_grad flags inside sentence_loss() and pkatz(), keeping requires_grad=True only on vec. That produces the following error on the first example:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

  • Lastly, I tried setting requires_grad=False in sentence_loss(), pkatz(), and vec. That produces the same error on the first example (a tiny standalone repro of this error is shown right after this list):

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
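
As far as I understand, that error is the generic one PyTorch raises whenever backward() is called on a result computed entirely from tensors that don't track gradients, e.g.:

import torch

x = torch.tensor([1.0, 2.0, 3.0])  # requires_grad defaults to False
loss = torch.log(x).sum()          # so no grad_fn gets recorded
loss.backward()                    # RuntimeError: element 0 of tensors does not
                                   # require grad and does not have a grad_fn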

If it would be helpful to see the entirety of the code, please send me a PM and I’ll be happy to provide it (it’s a little messy). Thank you in advance for your help!

Hi,

The print(output) and loss.backward() you mention above are the ones in the last lines of your code sample, right?

If you don’t set any breakpoint, is the code working fine?

If the structure of your code is:

    for i, input_str in enumerate(MyDataLoader, 0):
        output = model(input_str)
        print(output)
        loss = sentence_loss(output)
        loss.backward()

I don’t see how the last line can be called twice without the print being called.
Or do you have other calls to .backward() somewhere else in your code?

Hi @albanD. Without any breakpoints, the code still produces the same error. To illustrate the issue further, I set up my code as follows:

for i, input_str in enumerate(MyDataLoader, 0):
    output = model(input_str)
    print(output)
    loss = sentence_loss(output)
    loss.backward()
    print('pytorch is fantastic!')

and set another breakpoint at print('pytorch is fantastic!'). On the first two examples, that breakpoint is hit. On the third example, the breakpoint at print('pytorch is fantastic!') is not hit and loss.backward() gets called twice, producing the error on the second call.

The structure of my code does follow what is above, but I am using the Transformers library to handle the fine-tuning and GPT2. I’ve simplified it in this post because the real code is a bit more complicated. The issue might be related to Transformers, so I will definitely check with them.

Is there anything I could print out for loss before and/or after the loss.backward() call that might be useful? I can also inspect its attributes when the breakpoints are hit.
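
For instance, would dumping a few of the standard tensor attributes around the call tell us anything, something like:

print(loss)                         # value plus grad_fn in the repr
print(loss.requires_grad)           # True before the call in my case
print(loss.grad_fn)                 # the node at the top of the graph
print(loss.grad_fn.next_functions)  # the nodes backward() will visit next
print(loss.is_leaf)                 # False, since loss is a computed value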

On the third example, the breakpoint at print('pytorch is fantastic!') is not hit and loss.backward()

Does the print(output) get called again?
I don’t think it’s possible for this one line to get called again :slight_smile: Unless something goes horribly wrong with the python interpreter haha
The autograd most likely gets called from some other place in the code, no?
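
One quick way to check (just a debugging sketch, the wrapper name is arbitrary) is to temporarily wrap torch.Tensor.backward so that every call prints a stack trace showing where it came from:

import traceback
import torch

_orig_backward = torch.Tensor.backward

def backward_with_trace(self, *args, **kwargs):
    traceback.print_stack()  # show exactly who called backward()
    return _orig_backward(self, *args, **kwargs)

torch.Tensor.backward = backward_with_trace

If the second call comes from inside the library, the trace should show where.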

print(output) does not get called again. On the third example, the breakpoint hits are as follows:

print(output)
loss.backward()
loss.backward()
*exit with error*

I’ll go check with the Transformers folks to see what they think and mention your question about the autograd. It could very well be that backward() is being called somewhere else in the library. I also agree that I don’t see how loss.backward() could be called twice here yet not cause problems on the first two examples.