Trying to backward through the graph a second time: backward() with multiple agents

Hello,
I am facing an issue: when I call backward() on a loss accumulated across multiple models, I get a RuntimeError: Trying to backward through the graph a second time.

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
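For context, here is a minimal snippet (not my actual code) that raises the same error, just to illustrate the situation the message describes: a tensor created inside one iteration's graph is reused after that graph has already been freed by a previous backward() call.

import torch

x = torch.randn(3, requires_grad=True)
emb = x.exp()              # intermediate tensor, part of the first graph
loss1 = emb.sum()
loss1.backward()           # frees the saved tensors of the first graph

loss2 = (emb * 2).sum()    # reuses emb, so loss2 is still attached to the freed graph
loss2.backward()           # RuntimeError: Trying to backward through the graph a second time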

Anomaly detection:

0%| | 0/16896 [00:00<?, ?it/s][W python_anomaly_mode.cpp:104] Warning: Error detected in ReluBackward0. Traceback of forward call that caused the error:
File “/opt/conda/envs/myenv/lib/python3.9/runpy.py”, line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File “/opt/conda/envs/myenv/lib/python3.9/runpy.py”, line 87, in _run_code
exec(code, run_globals)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel_launcher.py”, line 17, in <module>
app.launch_new_instance()
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/traitlets/config/application.py”, line 992, in launch_instance
app.start()
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelapp.py”, line 711, in start
self.io_loop.start()
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/tornado/platform/asyncio.py”, line 195, in start
self.asyncio_loop.run_forever()
File “/opt/conda/envs/myenv/lib/python3.9/asyncio/base_events.py”, line 601, in run_forever
self._run_once()
File “/opt/conda/envs/myenv/lib/python3.9/asyncio/base_events.py”, line 1905, in _run_once
handle._run()
File “/opt/conda/envs/myenv/lib/python3.9/asyncio/events.py”, line 80, in _run
self._context.run(self._callback, *self._args)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelbase.py”, line 510, in dispatch_queue
await self.process_one()
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelbase.py”, line 499, in process_one
await dispatch(*args)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelbase.py”, line 406, in dispatch_shell
await result
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelbase.py”, line 729, in execute_request
reply_content = await reply_content
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/ipkernel.py”, line 411, in do_execute
res = shell.run_cell(
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/zmqshell.py”, line 531, in run_cell
return super().run_cell(*args, **kwargs)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 2945, in run_cell
result = self._run_cell(
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 3000, in _run_cell
return runner(coro)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/async_helpers.py”, line 129, in pseudo_sync_runner
coro.send(None)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 3203, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 3382, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 3442, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File “/tmp/ipykernel_58338/3069656723.py”, line 2, in <module>
dsm.train(training_set, lr= 0.001)
File “/tmp/ipykernel_58338/4067010658.py”, line 68, in train
embedding = self.embedding_function(current_segment, self.embedding)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)
File “/tmp/ipykernel_58338/4200841065.py”, line 15, in forward
activated_output = self.relu(final_embedding)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/activation.py”, line 102, in forward
return F.relu(input, inplace=self.inplace)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/functional.py”, line 1298, in relu
result = torch.relu(input)
(function _print_stack)

Adding retain_graph = True introduces another error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512, 128]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
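And for reference, this is the kind of situation that second error refers to (again just an illustrative snippet, not my code): a tensor that autograd saved for the backward pass is modified in place before backward() runs, so its version counter no longer matches the expected one.

import torch

x = torch.randn(3, requires_grad=True)
y = x.exp()            # exp() saves its output y for the backward pass
y.add_(1.0)            # in-place modification bumps y's version counter
y.sum().backward()     # RuntimeError: ... has been modified by an inplace operation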

Anomaly detection:

0%| | 1/16896 [00:00<48:34, 5.80it/s][W python_anomaly_mode.cpp:104] Warning: Error detected in MmBackward. Traceback of forward call that caused the error:
File “/opt/conda/envs/myenv/lib/python3.9/runpy.py”, line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File “/opt/conda/envs/myenv/lib/python3.9/runpy.py”, line 87, in _run_code
exec(code, run_globals)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel_launcher.py”, line 17, in <module>
app.launch_new_instance()
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/traitlets/config/application.py”, line 992, in launch_instance
app.start()
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelapp.py”, line 711, in start
self.io_loop.start()
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/tornado/platform/asyncio.py”, line 195, in start
self.asyncio_loop.run_forever()
File “/opt/conda/envs/myenv/lib/python3.9/asyncio/base_events.py”, line 601, in run_forever
self._run_once()
File “/opt/conda/envs/myenv/lib/python3.9/asyncio/base_events.py”, line 1905, in _run_once
handle._run()
File “/opt/conda/envs/myenv/lib/python3.9/asyncio/events.py”, line 80, in _run
self._context.run(self._callback, *self._args)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelbase.py”, line 510, in dispatch_queue
await self.process_one()
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelbase.py”, line 499, in process_one
await dispatch(*args)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelbase.py”, line 406, in dispatch_shell
await result
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/kernelbase.py”, line 729, in execute_request
reply_content = await reply_content
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/ipkernel.py”, line 411, in do_execute
res = shell.run_cell(
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/ipykernel/zmqshell.py”, line 531, in run_cell
return super().run_cell(*args, **kwargs)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 2945, in run_cell
result = self._run_cell(
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 3000, in _run_cell
return runner(coro)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/async_helpers.py”, line 129, in pseudo_sync_runner
coro.send(None)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 3203, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 3382, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/IPython/core/interactiveshell.py”, line 3442, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File “/tmp/ipykernel_58338/3069656723.py”, line 2, in <module>
dsm.train(training_set, lr= 0.001)
File “/tmp/ipykernel_58338/1605679448.py”, line 68, in train
embedding = self.embedding_function(current_segment, self.embedding)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)
File “/tmp/ipykernel_58338/4200841065.py”, line 14, in forward
final_embedding = self.output_layer(combined_output)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/linear.py”, line 96, in forward
return F.linear(input, self.weight, self.bias)
File “/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/functional.py”, line 1847, in linear
return torch._C._nn.linear(input, weight, bias)
(function _print_stack)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm


class embedding_function(nn.Module):
    def __init__(self, embedding_input_dim):
        super(embedding_function, self).__init__()
        # Project the previous embedding and the current 8-feature input segment,
        # concatenate them, and map the result back to the embedding dimension.
        self.embedding_layer1 = nn.Linear(embedding_input_dim, 2*embedding_input_dim)
        self.embedding_layer2 = nn.Linear(8, 2*embedding_input_dim)
        self.relu = nn.ReLU(inplace=False)
        self.output_layer = nn.Linear(4*embedding_input_dim, embedding_input_dim)

    def forward(self, input_tensor, input_embedding):
        generated_embedding = self.embedding_layer1(input_embedding)
        generated_input_presentation = self.embedding_layer2(input_tensor)
        combined_output = torch.cat((generated_embedding, generated_input_presentation), dim=0)
        final_embedding = self.output_layer(combined_output)
        activated_output = self.relu(final_embedding)

        return activated_output


class Agent(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Agent, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Treat the embedding as a (batch=1, seq_len=1, input_size) sequence for the LSTM
        x = x.unsqueeze(0).unsqueeze(0)
        _, (hidden, _) = self.lstm(x)
        output1 = self.fc(hidden.squeeze(0).squeeze(0))
        output = F.softmax(output1, dim=0)
        return output

This is the trainable module:

class DSM:
    def __init__(self, embedding_dim):
        self.states = {0: [0, 1, 4],
                       1: [2],
                       2: [2, 3],
                       3: [0, 4]}
        
        self.current_state= 0
        self.cursor = 0
        self.device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')

        self.embedding_size= embedding_dim
        self.agent_0 = Agent(input_size=self.embedding_size , hidden_size= 256, output_size= 5).to(self.device)
        self.agent_1 = Agent(input_size=self.embedding_size , hidden_size= 256, output_size= 5).to(self.device)
        self.agent_2 = Agent(input_size=self.embedding_size , hidden_size= 256, output_size= 5).to(self.device)
        self.agent_3 = Agent(input_size=self.embedding_size , hidden_size= 256, output_size= 5).to(self.device)
        self.agent_4 = Agent(input_size=self.embedding_size , hidden_size= 256, output_size= 5).to(self.device)

        self.embedding_function = embedding_function(embedding_input_dim=self.embedding_size)

        self.embedding_function.to(self.device)
        self.parameters = (
                            list(self.embedding_function.parameters()) +
                            list(self.agent_0.parameters()) +
                            list(self.agent_1.parameters()) +
                            list(self.agent_2.parameters()) +
                            list(self.agent_3.parameters()) +
                            list(self.agent_4.parameters())
                        )   
        self.optimizer = None
        self.criterion = nn.CrossEntropyLoss()
        self.EOF = False
       
    def train(self, dataset, lr, num_epochs = None):
        self.embedding_function.train()
        self.agent_0.train()
        self.agent_1.train()
        self.agent_2.train() 
        self.agent_3.train()
        self.agent_4.train()
        

        self.optimizer= optim.Adam(self.parameters, lr=lr)
        self.agents = [self.agent_0, self.agent_1, self.agent_2, self.agent_3, self.agent_4]
        loader = DataLoader(dataset, batch_size= 1, shuffle= False)
        losses=[]
        for file_data, file_labels in loader:
            file_data = torch.tensor(file_data, dtype=torch.float32)
            file_labels = torch.tensor(file_labels, dtype=torch.long)

            file_data= file_data.to(self.device)
            file_labels= file_labels.to(self.device)
            
            running_loss = 0
            action =torch.tensor([1.,0.,0.,0.]).to(self.device)

            self.embedding = torch.rand(self.embedding_size).to(self.device)
            self.cursor = 0
            for _ in tqdm(range(len(file_data))):
                self.optimizer.zero_grad()
                #Data and Labels
                current_segment = file_data[self.cursor]
                current_label = file_labels[self.cursor]

                #Generating the memory embedding
                embedding = self.embedding_function(current_segment, self.embedding)
                self.embedding = embedding.to(self.device)
                
                #Using the generated embedding and the adequate agent to predict the current state
                current_state = torch.argmax(action.cpu()).item()          
                action = self.agents[current_state](self.embedding)

                #Loss calculation and backpropagation through the embedding function and the agents
                loss = self.criterion(input=action.unsqueeze(0), target=current_label.unsqueeze(0)).clone()
                loss.backward()
                running_loss +=  loss
                
                self.optimizer.step()
                self.cursor = self.cursor +  1
            print(f"File total loss = {running_loss}  | Average file loss = {running_loss/len(file_data)}")

It appears the issue is coming from the embedding function: running the training while detaching the results of the embedding function seems to work.
However, I am not exactly sure what that means in terms of updating the model’s parameters when training while detached, and I do need the embedding function to train, so I don’t think that’s the solution I’m looking for.
I have been stuck on this issue for a couple of days now. I’ve done some research, tried to clone() the tensors, and also tried setting the inputs argument of backward(), but nothing has worked so far.
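By “training while detached” I mean a change like this inside the inner loop of train() (sketch only): the embedding carried over to the next iteration keeps its value but is cut off from the current iteration’s graph.

embedding = self.embedding_function(current_segment, self.embedding)
self.embedding = embedding.detach()   # carry the value forward, but not the computation graph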

Hi @kerr,

Try replacing running_loss += loss with

running_loss  += loss.item()

You’re accumulating your graph, which might lead to the in-place error.
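In other words (a sketch of the difference):

# keeps a reference to this iteration's graph alive:
running_loss += loss
# stores only the Python float, with nothing attached to the graph:
running_loss += loss.item()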

@AlphaBetaGamma96 I am applying backward() on the loss, not the running loss:

loss.backward(retain_graph = True)

Yes, you are. But you’re accumulating your gradients (via an in-place operation) from previous iterations in the running_loss object, because you’re not detaching your graph, you’re accumulating it. See if that resolves the issue.

Did you try my suggestion?

Yes, I apologize, I did try your solution, but I am getting the same error.

So, after looking through the stack trace, I think the error is due to the embedding_function (which contains the ReLU). However, because the error is detected in the backward pass (i.e. when the derivatives are computed), it usually originates in a function called after the one reported.

Looking through the DSM class, you do

self.embedding_function = embedding_function(embedding_input_dim=self.embedding_size)
self.embedding_function.to(self.device)

Could you try replacing it with,

self.embedding_function = embedding_function(embedding_input_dim=self.embedding_size)

self.embedding_function = self.embedding_function.to(self.device) # <<< change is here

It might not work, but give it a go.

Didn’t work, same issue.

Hi Kerr!

I don’t see any obvious place in your code where the “second time” backward
might be occurring. Could you post a fully-self-contained, runnable script
(use hard-coded or random data, as needed) that reproduces your issue? (Please
simplify the script as much as practical while still capturing your issue.) Please
also post the output you get when you run the script.

My guess is that the second-time backward is your real error, and that
retain_graph = True isn’t a valid fix. (optimizer.step() counts as
an inplace operation that can trigger an inplace-modification error on
your retained graph.)
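As a minimal illustration of that last point (a throwaway example, not your code):

import torch

lin = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

loss = lin(torch.randn(2, 4)).sum()
loss.backward(retain_graph=True)   # keeps the graph (and its saved weight tensor) alive
opt.step()                         # updates lin.weight in place, bumping its version counter
loss.backward()                    # RuntimeError: ... modified by an inplace operation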

This part makes me a little suspicious:

current_state = torch.argmax(action.cpu()).item()
action = self.agents[current_state](self.embedding)

because you are choosing on the fly which Agent to use for the forward / backward
pass. But it looks like it should be legitimate, because the torch.argmax() should
“break the computation graph” and prevent any backpropagation through the previous
action that presumably depended on a different Agent.
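(Concretely, argmax() returns an integer index tensor with no grad_fn, so nothing can flow back through the agent-selection step. A quick check:)

import torch

action = torch.softmax(torch.randn(5, requires_grad=True), dim=0)
idx = torch.argmax(action)    # LongTensor with requires_grad=False and no grad_fn
print(idx.requires_grad)      # False -- the graph is "broken" at the selection step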

Best.

K. Frank