RuntimeError when normalizing a ReLU input: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph

I am trying to take the max activation of the previous layer and use it to normalize the activation of the next layer in a model, as shown in the sample code below.
Here is the forward method of the model, where I record the max of each layer in a list (thresh_list):

def forward(self, input, epoch):
        x = self.conv1(input)
        x = self.relu(x,1)
        self.thresh_list[0] = max(self.thresh_list[0], x.max())  # here to get the max activation 
        x = self.conv_dropout(x)
        x = self.conv2(x)
        x = self.relu(x, self.thresh_list[0])
        self.thresh_list[1] = max(self.thresh_list[1], x.max())
        x = self.pool1(x)
        x = self.conv_dropout(x)
        x = self.conv3(x)
        x = self.relu(x, self.thresh_list[1] )
        self.thresh_list[2] = max(self.thresh_list[2], x.max())

The ReLU function I call is a custom one, created as below:

self.relu = th_norm_ReLU(True)

and the th_norm_ReLU class is defined as below:

import torch
import torch.nn as nn
import torch.nn.functional as F
from argument_settings import *

device = "cuda" if torch.cuda.is_available() else "cpu"

class thReLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # zeros_like keeps the device and dtype of the input,
        # so this also works when the input is on the CPU
        zero_tensor = torch.zeros_like(input)
        output = torch.maximum(input, zero_tensor)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # pass the gradient only where the input was non-negative
        diff = (input >= 0)
        return grad_input * diff.float()

act_fun = thReLU.apply


class th_norm_ReLU(nn.Module):
    def __init__(self, modify):
        super(th_norm_ReLU, self).__init__()
        self.norm = act_fun
       

    def forward(self, input, prev_layer_max):
      output = input * (prev_layer_max / input.max())
      norm_output = self.norm (output)
      return norm_output

if __name__ == '__main__':
    xrelu = th_norm_ReLU(True)
    z = torch.rand(3, 3) * 4
    print(z)
    y = xrelu(z, 1)
    print(y)

Can anyone please help me solve this problem? When I pass a fixed value instead of the max of the previous layer, it works with no error. But when I pass the max of the previous layer, I encounter this error.

Simply put, I want to normalize the ReLU input using the max of the previous layer, but I encounter this error if I use the code below and pass the argument prev_layer_max:

class th_norm_ReLU(nn.Module):
    def __init__(self, modify):
        super(th_norm_ReLU, self).__init__()
        self.therelu = F.relu
       

    def forward(self, input, prev_layer_max):
      output = input * (prev_layer_max / input.max())
      norm_output = self.therelu (output)
      
      return norm_output

If I use a constant value instead of prev_layer_max, as in the code below where it is set to 1, it works normally without any error:

def forward(self, input, prev_layer_max = 1):
      output = input * (1 / input.max())
      norm_output = self.therelu (output)

Please help me with this problem.

Do you want to backprop through the max? If not, then maybe do

self.thresh_list[0] = max(self.thresh_list[0], x.max()).detach()

to avoid backprop through the max…

Thank you so much for your reply. Yes, it works normally now. Could you please explain this issue and why we need to use detach()? I would really appreciate it.

.detach() means that the variable is simply treated as a constant by autograd, i.e. there’s no differentiation through that variable (via the chain rule). Please note that it could screw up your training if in fact you do need to backprop through it…
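To make this concrete, here is a minimal sketch of what is probably going on, assuming thresh_list persists between forward calls. A max stored during one iteration still carries that iteration's graph; when it is reused in the next iteration, backward tries to walk the old graph, whose saved tensors have already been freed. The names train_two_steps, raises_runtime_error and stored_max are made up for illustration, and a small Linear layer stands in for the conv layers.

```python
import torch

def train_two_steps(detach_max):
    """Run two toy steps; the max stored in step 0 is reused in step 1."""
    torch.manual_seed(0)
    layer = torch.nn.Linear(4, 4)  # hypothetical stand-in for conv1
    stored_max = None              # persists across steps, like thresh_list
    for step in range(2):
        x = torch.relu(layer(torch.randn(8, 4)))
        # Step 0 uses a constant; step 1 uses the value stored in step 0.
        scale = 1.0 if stored_max is None else stored_max
        loss = (x * scale).sum()
        new_max = x.max()
        # Without detach(), stored_max stays attached to this step's graph.
        stored_max = new_max.detach() if detach_max else new_max
        layer.zero_grad()
        loss.backward()  # frees the saved tensors of this step's graph

def raises_runtime_error(detach_max):
    """Return True if the two-step loop raises a RuntimeError."""
    try:
        train_two_steps(detach_max)
    except RuntimeError:
        return True
    return False

print("error without detach:", raises_runtime_error(detach_max=False))
print("error with detach:", raises_runtime_error(detach_max=True))
```

This would also explain why a constant such as 1 works: a plain number has no graph attached, so nothing from an earlier iteration is pulled into the backward pass.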

Got it, thank you so much for your reply.

I am wondering why using detach() solves my problem but training now takes much longer. How can I solve this issue, please?

OK then maybe you indeed need to backprop through the max? The problem could be the attribute assignment self.thresh_list[0] = ..., which is then differentiated. Maybe better to return a thresh_list variable and then use that in the loss. Not quite sure though.
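If it turns out that you do not need gradients through the max, another option is to keep the running maxima in a registered buffer and update them under torch.no_grad(). This is only a sketch: MaxTracker, update and thresh are hypothetical names, and I am not certain it addresses the slowdown. One possible factor is that Python's built-in max() on tensors converts the comparison result to a Python bool, which forces a device sync on GPU; torch.maximum avoids that.

```python
import torch
import torch.nn as nn

class MaxTracker(nn.Module):
    """Hypothetical helper: one running max per layer, kept as a buffer."""

    def __init__(self, num_layers):
        super().__init__()
        # Buffers move with .to(device) and are never trained.
        self.register_buffer("thresh", torch.zeros(num_layers))

    @torch.no_grad()
    def update(self, index, x):
        # torch.maximum compares on the tensor's device, without a
        # Python-side comparison of tensor values.
        self.thresh[index] = torch.maximum(self.thresh[index], x.max())
        # Return a copy so later in-place updates of the buffer do not
        # touch a tensor that autograd may have saved for backward.
        return self.thresh[index].clone()

tracker = MaxTracker(num_layers=3)
x = torch.tensor([[0.5, 2.0], [1.0, 0.1]], requires_grad=True)
first = tracker.update(0, x)         # running max becomes 2.0
second = tracker.update(0, x * 0.5)  # smaller max, running max stays 2.0
print(first.item(), second.item(), first.requires_grad)
```

In the model's forward, the returned value could then be passed as prev_layer_max to the next th_norm_ReLU call; it is a plain constant as far as autograd is concerned.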