Trying to use an NN output from a previous episode

I’m trying to implement an experience replay buffer that uses state information from previous episodes, but when I update gradients, I can’t use outputs from previous episodes; it raises an error. Is there a way to calculate gradients for previous episodes? I added a simple example that reproduces the error.

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:

import torch.nn as nn
import torch

class Model(nn.Module):
    def __init__(self,in_state,out_feats):
        super(Model, self).__init__()
        self.apply_1 = nn.ModuleList([nn.Linear(4*in_state, 32),
                                      nn.ReLU(),
                                      nn.Linear(32, out_feats),
                                      nn.Softmax(dim=1)])
                                    
    def forward(self,inp):
        hidden = inp
        for layer in self.apply_1:
            hidden = layer(hidden)
        return hidden
  
    
mymodel = Model(3,12)
inp1 = torch.rand(1,12)
inp2 = torch.rand(1,12)
out1 = mymodel(inp1) 
out2 = mymodel(inp2)  # <--- try to use out2 after optimizer.step() for gradient computation

optimizer = torch.optim.Adam(mymodel.parameters(),lr=0.1)

loss = torch.log(out1[0][0])*-10
optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()
loss = torch.log(out2[0][0])*-10
optimizer.zero_grad()
loss.backward(retain_graph=True)  # <--- raises the error at this stage
optimizer.step()

If you want to use 2 losses on the same model, try this:

optimizer.zero_grad()
loss1 = torch.log(out1[0][0])*-10
loss2 = torch.log(out2[0][0])*-10
loss1.backward(retain_graph=True)
loss2.backward() 
optimizer.step()
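
Equivalently, since out1 and out2 come from separate forward passes, you can sum the two losses and call backward once; a minimal sketch (no retain_graph needed in this variant):

optimizer.zero_grad()
loss1 = torch.log(out1[0][0])*-10
loss2 = torch.log(out2[0][0])*-10
(loss1 + loss2).backward()   # one backward pass through both graphs
optimizer.step()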

Also, in the example you posted there is:

nn.Linear(32,out_feats)]),  <------ probably a typo,
nn.Softmax(dim=1)           <------ because the closing ]) should come after this line, so the ModuleList also includes the last layer

My point is that I want to use the loss coming from out2 after the first episode is finished (after my network weights are updated). Since out2 was computed with the weights from before the update, backpropagating through it after the update gives the error.

Yes, it is a typo. I fixed it.

Ok, I hope I understand what you're trying to do (sounds like RL, btw);
if not, I'll try again 🙂
Doing the following will work. Does this fit your scenario?

loss = torch.log(out1[0][0])*-10
optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()
out2 = mymodel(inp2)    #<--- move to here, after weights were updated from inp1
loss = torch.log(out2[0][0])*-10
optimizer.zero_grad()
loss.backward()         # <--- do you need retain_graph here?
optimizer.step()

Roy.

BTW 1:
If this is indeed RL and you're dealing with a DDQN-like scenario, you might need two separate networks and to copy the weights from the source network to the target network.
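
For reference, a minimal sketch of that pattern, assuming two copies of your Model class (the policy_net/target_net names here are just placeholders):

import copy

policy_net = Model(3, 12)
target_net = copy.deepcopy(policy_net)   # second network with identical weights

# ... train policy_net for a while ...

# periodically copy the weights from source to target
target_net.load_state_dict(policy_net.state_dict())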

BTW 2:
Using nn.Sequential in your example might be more suitable:

class Model(nn.Module):
    def __init__(self,in_state,out_feats):
        super(Model, self).__init__()
        self.apply_1 = nn.Sequential(nn.Linear(4*in_state, 32),
                                     nn.ReLU(),
                                     nn.Linear(32, out_feats),
                                     nn.Softmax(dim=1))
                                    
    def forward(self,inp):
        return self.apply_1(inp)

Roy.

Your code will work with a CloningLinear layer (again, I hope I understand what you're trying to do):

import torch.nn as nn
import torch.nn.functional as F
import torch

class CloningLinear(nn.Linear):
    # clone the parameters on every forward pass: the cloned tensors live in
    # the autograd graph, so optimizer.step() updating self.weight/self.bias
    # in-place no longer invalidates the tensors saved for backward
    def forward(self, inputs):
        return F.linear(inputs, self.weight.clone(), self.bias.clone())

class Model(nn.Module):
    def __init__(self,in_state,out_feats):
        super(Model, self).__init__()
        self.apply_1 = nn.ModuleList([CloningLinear(4*in_state, 32),
                                      nn.ReLU(),
                                      CloningLinear(32, out_feats),
                                      nn.Softmax(dim=1)])
                                    
    def forward(self,inp):
        hidden = inp
        for layer in self.apply_1:
            hidden = layer(hidden)
        return hidden
  
    
mymodel = Model(3,12)
inp1 = torch.rand(1,12)
inp2 = torch.rand(1,12)
out1 = mymodel(inp1) 
out2 = mymodel(inp2)

optimizer = torch.optim.Adam(mymodel.parameters(),lr=0.1)

loss = torch.log(out1[0][0])*-10
optimizer.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()
loss = torch.log(out2[0][0])*-10
optimizer.zero_grad()
loss.backward(retain_graph=True) #<--- still think you should reconsider this retain_graph
optimizer.step()

Yes, you understood correctly. CloningLinear is exactly what I was looking for. My buffer collects outputs from different episodes, and those outputs depend on different weight values because each comes from a different episode. This approach works for me. Thank you!
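
For later readers, a minimal sketch of what such a buffer could look like with the CloningLinear model above (the buffer layout here is an assumption, not the original poster's code):

buffer = []

for episode in range(3):
    inp = torch.rand(1, 12)
    out = mymodel(inp)            # the graph saves cloned weights, so it
    buffer.append(out)            # survives later optimizer.step() calls

    loss = torch.log(out[0][0])*-10
    optimizer.zero_grad()
    loss.backward(retain_graph=True)   # keep the graph alive for the buffer
    optimizer.step()

# later: gradients through an output stored several updates ago
old_loss = torch.log(buffer[0][0][0])*-10
optimizer.zero_grad()
old_loss.backward()
optimizer.step()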
