Cannot find inplace operation using set_detect_anomaly(True)

I am building an LSTM model for time series prediction. I start with a pandas DataFrame, load it into a TensorDataset, and train the model.

I have run into this error a few times before and managed to fix it, but this time I just cannot find the in-place operation in question, so I would really appreciate a second pair of eyes.

Before forming the TensorDataset I did have to perform some operations in pandas to construct the dataset, but I don't think that should affect autograd, right?

Code snippets:

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataset import random_split

leads_df = leads_df.resample('6H').sum()   # downsample the frame to 6-hour buckets
leads_df_var = leads_df.iloc[:, :10]       # keep the first 10 columns as features

def generate_sequences(df, tw):
    """Build (sequence, target) pairs: tw consecutive rows predict the following row."""
    data = list()
    L = len(df)
    for i in range(L - tw):
        sequence = df[i:i+tw].values        # shape (tw, n_features)
        target = df[i+tw:i+tw+1].values     # shape (1, n_features)
        data.append((sequence, target))
    return data

def list_of_tensors_to_dataset(data):
    """Stack the numpy pairs into float tensors and wrap them in a TensorDataset."""
    tensor_x = torch.Tensor([x[0] for x in data])   # (n_samples, tw, n_features)
    tensor_y = torch.Tensor([x[1] for x in data])   # (n_samples, 1, n_features)
    return TensorDataset(tensor_x, tensor_y)

seq_len = 4
data = generate_sequences(leads_df_var, seq_len)
dataset = list_of_tensors_to_dataset(data)

num_features = dataset[0][0].shape[1]
output_size = dataset[0][0].shape[1]     # one predicted value per input feature
hidden_size = 128
batch_size = 16

split = 0.8   # train fraction; assumed value, the definition was not shown in the post
train_len = int(len(dataset) * split)
lens = [train_len, len(dataset) - train_len]
dataloader = DataLoader(dataset, batch_size=batch_size)
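
For reference, with the 10-column frame and seq_len = 4 above, each sample should come out with these shapes (illustrative check only):

seq, target = dataset[0]
print(seq.shape)      # torch.Size([4, 10])  -> (seq_len, num_features)
print(target.shape)   # torch.Size([1, 10])  -> hence the label.squeeze() in the training loop below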

Model:

import torch.nn as nn

class LeadsPredictor(nn.Module):

    def __init__(self, num_features, seq_len, hidden_size, output_size):
        super().__init__()

        self.num_features = num_features
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.output_size = output_size
        self.n_layers = 1
        self.lstm = nn.LSTM(self.num_features,
                            self.hidden_size,
                            batch_first=True)

        # map the flattened LSTM outputs of the whole sequence to one prediction
        self.fc = nn.Linear(self.hidden_size * self.seq_len, self.output_size)

        # initial hidden/cell state, stored on the module (batch_size is the global defined above)
        hidden_state = torch.zeros(self.n_layers, batch_size, self.hidden_size)
        cell_state = torch.zeros(self.n_layers, batch_size, self.hidden_size)

        self.hidden = (hidden_state, cell_state)

    def forward(self, x):
        # the LSTM's returned state is written back to self.hidden on every call
        x, self.hidden = self.lstm(x, self.hidden)
        x = x.contiguous().view(batch_size, -1)
        x = self.fc(x)
        return x
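
One thing worth noting on the side: both the stored hidden state and the view(batch_size, -1) assume every batch contains exactly batch_size samples, so the final, smaller batch from the DataLoader would raise a shape mismatch unless it is dropped, for example:

dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)   # skip the incomplete last batch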

Train function:

import time
import matplotlib.pyplot as plt
from tqdm import tqdm

def train(model, dataloader, num_epochs):
    with torch.autograd.set_detect_anomaly(True):
        model.train()
        losses = list()
        ts = time.time()
        for epoch in tqdm(range(num_epochs)):
            epoch_losses = list()
            for idx, (seq, label) in enumerate(dataloader):

                optimizer.zero_grad()
                out = model(seq)
                loss = criterion(out, label.squeeze())
                loss.backward(retain_graph=True)
                optimizer.step()
                epoch_losses.append(loss.item())

            losses.append(sum(epoch_losses) / len(epoch_losses))   # mean epoch loss, so the plot below has data

        te = time.time()
        fig, ax = plt.subplots()
        ax.plot(range(num_epochs), losses)
        plt.show()
        mins = int((te - ts) / 60)
        secs = int((te - ts) % 60)
        print('Training completed in {} minutes, {} seconds.'.format(mins, secs))
        return losses, model
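
criterion and optimizer are globals used inside train; a minimal setup along these lines is assumed (the exact loss and optimizer were not shown in the post):

import torch.optim as optim

model = LeadsPredictor(num_features, seq_len, hidden_size, output_size)
criterion = nn.MSELoss()                      # assumed regression loss
optimizer = optim.Adam(model.parameters())    # assumed optimizer

losses, model = train(model, dataloader, 5)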

Error Message:
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 512]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I would really, really appreciate some help with this.

Before forming the Tensor dataset I did have to perform some operations to construct the dataset in pandas but I don’t think that should affect autograd, right?

No, that won't affect autograd.

Can you share the traceback reported by detect_anomaly? Where does it point?

RuntimeError                              Traceback (most recent call last)
<ipython-input-244-b5623f61a7a3> in <module>
----> 1 losses, model = train(model, dataloader, 5)

<ipython-input-242-54990647c901> in train(model, dataloader, num_epochs)
     13                 out = model(seq)
     14                 loss = criterion(out, label)
---> 15                 loss.backward(retain_graph=True)
     16                 optimizer.step()
     17 

~\anaconda3\envs\MarketingAnalytics\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
    183                 products. Defaults to ``False``.
    184         """
--> 185         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    186 
    187     def register_hook(self, hook):

~\anaconda3\envs\MarketingAnalytics\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    123         retain_graph = create_graph
    124 
--> 125     Variable._execution_engine.run_backward(
    126         tensors, grad_tensors, retain_graph, create_graph,
    127         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 512]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

@albanD This is the full traceback

Don’t you have a second traceback (as a warning) just above this one?

@albanD No, there's nothing above that. Maybe I'm filtering out the warnings, although I haven't added anything that filters them. This was my main confusion too, as I was expecting a pointer to the failing operation.

That is very surprising…
You can still check for in-place ops on the result of a transpose op (the error mentions TBackward, the backward of a transpose).
Common transposes happen in the Linear layer's forward.

Why do you set retain_graph=True in your code btw?

Doing things like this would cause this error:

model = nn.Linear(...)
opt = optim.SGD(model.parameters(), ...)

out = model(inp)
out.backward(retain_graph=True)
opt.step()      # modifies the linear weights in place
out.backward()  # Here, the linear weights were modified inplace

@albanD

So I do have a linear layer in my network. Would this correspond to that? What might be the correct way to do it?

I noticed that in your snippet you call backward twice. If this is in reference to my retain_graph=True parameter: when I remove it, I get the following error:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

I understand that this is not normal, so is there something I should fix there?

Also, thank you so much for your help so far

Oh, I just noticed that you update self.hidden in your forward function.
That means that the next iteration will depend on the previous one.

I am not sure how you want to handle this, but you should either .detach() the hidden state if you want to keep the updated value without having gradients flow back all the way to the previous iteration,
or, if the initial hidden state should be fixed to a constant value, re-initialize it with a Tensor that has no history.
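
In code, those two options could look roughly like this (sketches only, reusing the names from the model above):

# Option 1: keep the running hidden state, but cut its history
def forward(self, x):
    h, c = self.hidden
    x, self.hidden = self.lstm(x, (h.detach(), c.detach()))   # values carry over, gradients do not
    x = self.fc(x.contiguous().view(x.size(0), -1))
    return x

# Option 2: always start from a fixed, history-free initial state
def forward(self, x):
    h0 = torch.zeros(self.n_layers, x.size(0), self.hidden_size)
    c0 = torch.zeros(self.n_layers, x.size(0), self.hidden_size)
    x, _ = self.lstm(x, (h0, c0))
    x = self.fc(x.contiguous().view(x.size(0), -1))
    return x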


Ah I think that worked! Thank you so much!

I have changed the line:

x, self.hidden = self.lstm(x, self.hidden)

to:

x, _ = self.lstm(x, self.hidden)

So that self.hidden is no longer updated on each forward pass.
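
With that change the graph no longer spans iterations (self.hidden stays at the zero state created in __init__, which is essentially the second option above), so the retain_graph=True workaround in the training loop should no longer be needed either:

loss.backward()   # retain_graph=True can be dropped once iterations no longer share a graph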