Loss is computed but the network weights are not updated and the network does not learn

Dear all, I am having issues with training my feed-forward NNs.
I have tried my best to resolve them but have failed so far, so I am turning to you all for some help.
Let me describe the whole code as concisely as possible.

Basically, this is an actor-critic structure for finding an optimal control using the HDP scheme.
There are two networks, criticNN and actorNN, both feed-forward NNs. Their common structure is defined by the class below; the loss and optimizers are set up further down.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # only needed for the Variable(...) call further below


class Network(nn.Module):
    def __init__(self, D_in, D_out):
        super().__init__()

        # input-to-hidden linear transformation
        self.lin1 = nn.Linear(D_in, 8)
        # second hidden layer
        self.lin2 = nn.Linear(8, 4)
        # output layer
        self.output = nn.Linear(4, D_out)

    def forward(self, x):
        # the data flows through the layers in the order given here:
        # two ReLU hidden layers followed by a tanh output
        x = self.lin1(x)
        x = F.relu(x)

        x = self.lin2(x)
        x = F.relu(x)

        x = self.output(x)
        y = torch.tanh(x)

        return y


actorNN, criticNN and the model NN are instantiated, the loss and the optimizers
are set up, and some essential helper functions are declared next.

criticNN = Network(2, 1)
actorNN = Network(2, 1)
model = Network(3, 2)   # the system model NN (used pretrained, see below)

# sum-reduced MSE loss ('size_average=False' is deprecated; reduction='sum' is the equivalent)
criterion = torch.nn.MSELoss(reduction='sum')

optimizer_c = torch.optim.SGD(criticNN.parameters(), lr=0.001)
optimizer_a = torch.optim.SGD(actorNN.parameters(), lr=0.001)

def to_np(x):
    return x.data.cpu().numpy()

def instant_reward(state,control,Q,R):
    # quadratic stage cost: r = state^T Q state + R * control^2
    cc=np.matmul(Q, state)
    rew1=np.matmul(state.T,cc)
    rew2=control*control*R
    res=rew1+rew2
    return res

q11,q12,q21,q22=1,0,0,1 
Q=np.array([[q11,q12],[q21,q22]])
R=0.001
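
As a quick sanity check of instant_reward with the Q and R just defined (the *_check names below are only for illustration):

x0_check = np.array([[0.1, 0.1]])
u_check = np.array([[0.001]])
r_check = instant_reward(state=x0_check.reshape(-1, 1), control=u_check, Q=Q, R=R)
# r_check is a 1x1 array: 0.1^2 + 0.1^2 + 0.001 * 0.001^2 ≈ 0.02
print(r_check)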

The rest of the problem is described in the comments accompanying the code,
to make it easier to follow; please see the comments in the code.

# a pretrained NN called "model" is used here in eval mode;
# this will be used in reward generation later
model.eval()
######## essential parameters for training and logging ##############


train_losses_c = []
train_losses_a = []
valid_losses = []
valid_score = []
epochs_c=[]
epochs_a=[]

#### epochs for actor and critic training ##############
epoch_iter_c =range(1, 10)
epoch_iter_a =range(1, 10)

# cost weights used during training (these override the values defined earlier)
q11,q12,q21,q22=10,0,0,10 
Q=np.array([[q11,q12],[q21,q22]])
R=1

#### a single input data sample for training actorNN and criticNN
x0=np.array([[0.1,0.1]])

##### training of the critic and the actor
###### the iteration range is defined; in each iteration the
#critic is trained for several epochs, followed by the actorNN

for iteration in range(1,4):
    print ("iteration=",iteration)
    
    #in the first iteration, "control" and the value (cost function) are initialized
    if iteration==1:
    
        control=np.array([[0.001]])
        value=0
        print("initial action value chosen",control)
        print("initial value is",value)
    #for the other iterations, the control value is produced by the
    #previously trained actorNN (i.e. trained over several epochs in the previous iteration);
    #similarly, "value" is produced by the previously trained criticNN (i.e. trained in the previous iteration)
    else:
        #torch.no_grad() is used because the actorNN and criticNN used here were trained in the previous iteration
        with torch.no_grad():

            #actorNN.eval()
            #calculate the action from the already trained actorNN
            action=actorNN(torch.from_numpy(x0).float())
            control=to_np(action)
            print("control value outputted by trained actorNN is",control)

            #criticNN.eval()
            #calculate cost function (value) from already trained critic NN
            concat_input_system_model_c=np.concatenate((x0,control),axis=1)
            #prepare tensor to feed into trained system NN 
            concat_input_system_model_c_torch=torch.from_numpy(concat_input_system_model_c)
            #feed in to the system NN
            x1=model(concat_input_system_model_c_torch.float())

            value=criticNN(x1)
            print("value outputted by trained criticNN is",value)
        
        
    #in each iteration, the critic and the actorNN are trained for several epochs, one after the other.
    idx_c=0 
    idx_a=0
    tloss_avg_c=0
    tloss_avg_a=0
    
    # training of criticNN for several epochs 
    #the criticNN is put in train mode
    criticNN.train()
    
    for epoch in epoch_iter_c:   
        epochs_c.append(epoch)
        idx_c+=1

        
        #criticNN estimation
        estimation_criticNN=criticNN(torch.from_numpy(x0).float())
        
        #the reward is generated and then converted to a torch tensor;
        #this reward generation requires the "control" produced above by the previously trained actorNN
        reward=instant_reward(state=x0.reshape(-1,1),control=control,Q=Q,R=R)
        reward_tensor=torch.from_numpy(reward).float()
        
        #calculate the target
        #the target calculation requires the "value" generated above by the trained criticNN,
        #which is computed outside the epoch loop, inside the if-else statement after the iteration assignment.
        #that is why no_grad is activated here: backprop should only be done
        #in the criticNN that produced "estimation_criticNN" above.
        with torch.no_grad():
            target = reward_tensor + value # target = instant reward + value from the previously trained critic

        
        
        loss_c = criterion(estimation_criticNN,target) # compute loss

        optimizer_c.zero_grad() # clear gradients
        loss_c.backward() # compute gradients
        optimizer_c.step() # apply gradients

        tloss_avg_c += loss_c.item()

        tloss_avg_c /= idx_c
        train_losses_c.append(tloss_avg_c)

        print(" Epoch : %s , Critic train loss: %s " %(epoch,tloss_avg_c))
    
    
    #critic training is done;
    #keep the trained criticNN weights frozen and train the actorNN
    
    ############ train actor #################################################
    
    criticNN.eval()
    
    actorNN.train()

    for epoch in epoch_iter_a:   

        epochs_a.append(epoch)
        idx_a += 1
        
        #the control estimate is produced by the actorNN, which is now being trained
        estimation_control=actorNN(torch.from_numpy(x0).float())
        #convert to numpy
        estimation_control_numpy = to_np(estimation_control)
        print("Control value is",estimation_control_numpy)
        
        #the reward is generated using the control estimate produced by the actorNN just above
        reward_a        = instant_reward(state=x0.reshape(-1,1),control=estimation_control_numpy,Q=Q,R=R)
        reward_a_tensor = torch.from_numpy(reward_a).float()
        #print("reward_a_tensor",reward_a_tensor)
        
        #concatenation of x0 and estimation_control to feed into "model"
        concat_input_system_model_a=np.concatenate((x0,estimation_control_numpy),axis=1)      
        #prepare tensor to feed into trained system NN 
        concat_input_system_model_a_torch = torch.from_numpy(concat_input_system_model_a)
        #feed into the system NN to obtain the next state x1 from x0
        x1_a=model(concat_input_system_model_a_torch.float())
        #print("x1_a",x1_a)
        
        
        #calculate the target for actor training
        #the target is calculated using the reward generated just above
        #(which is sensitive to the actorNN's estimate) and the prediction of the criticNN trained above;
        #the criticNN here is in eval mode
        
        target_a = reward_a_tensor + criticNN(x1_a) # target = reward + critic's estimate of the cost-to-go
        print("target for action network is",target_a)
        #this target value is to be minimised. thus, grad is set true.
        target_a = Variable(target_a, requires_grad = True)
        
        
        #create a zeros-like tensor to compare the target against
        zeros_like=torch.zeros_like(target_a)
        #print("zeros like",zeros_like)
        
        #the loss compares the target with zero, since the target value
        #is to be minimised by the actor weight update
        loss_a = criterion(target_a,zeros_like) # compute loss
        #print("loss a",loss_a)
        
        #loss_a = Variable(loss_a, requires_grad = True)
        
        optimizer_a.zero_grad() # clear gradients
        loss_a.backward() # compute gradients
        optimizer_a.step() # apply gradients

        tloss_avg_a += loss_a.item()
        #print("loss item",loss_a.item())
        
        tloss_avg_a /= idx_a
        train_losses_a.append(tloss_avg_a)

        print(" Epoch : %d , Actor Train loss: %s " %(epoch,tloss_avg_a))
        print('')

On running it, we can see that the loss of the critic goes down, as indicated by “Critic train loss”, and the value outputted by the critic also changes, as indicated by “value outputted by trained criticNN”, from iteration 2 onwards.

The loss of the actorNN also goes down in each iteration, as indicated by “Actor Train loss”,
but the value outputted by the actorNN does not change, as indicated by “Control value is”,
and hence the actorNN does not learn.

This is my main problem. I am thankful for all the help.
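
One way to double-check that the actor weights really are not being updated is to snapshot a parameter before optimizer_a.step() and compare it afterwards, or to print the gradients right after loss_a.backward(). A minimal sketch (illustration only, not part of the run shown below):

# inside the actor epoch loop, after computing loss_a:
before = [p.detach().clone() for p in actorNN.parameters()]

optimizer_a.zero_grad()
loss_a.backward()
for name, p in actorNN.named_parameters():
    print(name, "grad sum:", None if p.grad is None else p.grad.abs().sum().item())

optimizer_a.step()
changed = any((p.detach() - b).abs().sum().item() > 0
              for p, b in zip(actorNN.parameters(), before))
print("actor weights changed:", changed)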
Here is what is obtained after running it:

Note: the criticNN and actorNN are trained here using one data sample x0 and different targets. I have tried training in a similar way using batches of x0, but with the same results.

iteration= 1
initial action value chosen [[0.001]]
initial value is 0
 Epoch : 1 , Critic train loss: 0.004236207809299231 
 Epoch : 2 , Critic train loss: 0.004186923382803798 
 Epoch : 3 , Critic train loss: 0.002742775638277332 
 Epoch : 4 , Critic train loss: 0.0016725545089381435 
 Epoch : 5 , Critic train loss: 0.0011056517212030786 
 Epoch : 6 , Critic train loss: 0.0008119643504162216 
 Epoch : 7 , Critic train loss: 0.0006415206117607239 
 Epoch : 8 , Critic train loss: 0.0005293509762042151 
 Epoch : 9 , Critic train loss: 0.0004488066545300284 
Control value is [[0.32846534]]
target for action network is tensor([[0.4479]], grad_fn=<AddBackward0>)
 Epoch : 1 , Actor Train loss: 0.20059673488140106 

Control value is [[0.32846534]]
target for action network is tensor([[0.4479]], grad_fn=<AddBackward0>)
 Epoch : 2 , Actor Train loss: 0.20059673488140106 

Control value is [[0.32846534]]
target for action network is tensor([[0.4479]], grad_fn=<AddBackward0>)
 Epoch : 3 , Actor Train loss: 0.1337311565876007 

Control value is [[0.32846534]]
target for action network is tensor([[0.4479]], grad_fn=<AddBackward0>)
 Epoch : 4 , Actor Train loss: 0.08358197286725044 

Control value is [[0.32846534]]
target for action network is tensor([[0.4479]], grad_fn=<AddBackward0>)
 Epoch : 5 , Actor Train loss: 0.0568357415497303 

Control value is [[0.32846534]]
target for action network is tensor([[0.4479]], grad_fn=<AddBackward0>)
 Epoch : 6 , Actor Train loss: 0.04290541273852189 

Control value is [[0.32846534]]
target for action network is tensor([[0.4479]], grad_fn=<AddBackward0>)
 Epoch : 7 , Actor Train loss: 0.03478602108856042 

Control value is [[0.32846534]]
target for action network is tensor([[0.4479]], grad_fn=<AddBackward0>)
 Epoch : 8 , Actor Train loss: 0.029422844496245187 

Control value is [[0.32846534]]
target for action network is tensor([[0.4479]], grad_fn=<AddBackward0>)
 Epoch : 9 , Actor Train loss: 0.025557731041960696 

iteration= 2
control value outputted by trained actorNN is [[0.32846534]]
value outputted by trained criticNN is tensor([[0.1400]])
 Epoch : 1 , Critic train loss: 0.09390095621347427 
 Epoch : 2 , Critic train loss: 0.09281281009316444 
 Epoch : 3 , Critic train loss: 0.06080528721213341 
 Epoch : 4 , Critic train loss: 0.037084988318383694 
 Epoch : 5 , Critic train loss: 0.024520622007548808 
 Epoch : 6 , Critic train loss: 0.018012114086498818 
 Epoch : 7 , Critic train loss: 0.01423531560936854 
 Epoch : 8 , Critic train loss: 0.011750161385584977 
 Epoch : 9 , Critic train loss: 0.009965951214901729 
Control value is [[0.32846534]]
target for action network is tensor([[0.4790]], grad_fn=<AddBackward0>)
 Epoch : 1 , Actor Train loss: 0.22944745421409607 

Control value is [[0.32846534]]
target for action network is tensor([[0.4790]], grad_fn=<AddBackward0>)
 Epoch : 2 , Actor Train loss: 0.22944745421409607 

Control value is [[0.32846534]]
target for action network is tensor([[0.4790]], grad_fn=<AddBackward0>)
 Epoch : 3 , Actor Train loss: 0.15296496947606406 

Control value is [[0.32846534]]
target for action network is tensor([[0.4790]], grad_fn=<AddBackward0>)
 Epoch : 4 , Actor Train loss: 0.09560310592254004 

Control value is [[0.32846534]]
target for action network is tensor([[0.4790]], grad_fn=<AddBackward0>)
 Epoch : 5 , Actor Train loss: 0.06501011202732723 

Control value is [[0.32846534]]
target for action network is tensor([[0.4790]], grad_fn=<AddBackward0>)
 Epoch : 6 , Actor Train loss: 0.04907626104023721 

Control value is [[0.32846534]]
target for action network is tensor([[0.4790]], grad_fn=<AddBackward0>)
 Epoch : 7 , Actor Train loss: 0.03978910217919047 

Control value is [[0.32846534]]
target for action network is tensor([[0.4790]], grad_fn=<AddBackward0>)
 Epoch : 8 , Actor Train loss: 0.033654569549160816 

Control value is [[0.32846534]]
target for action network is tensor([[0.4790]], grad_fn=<AddBackward0>)
 Epoch : 9 , Actor Train loss: 0.029233558195917428 

iteration= 3
control value outputted by trained actorNN is [[0.32846534]]
value outputted by trained criticNN is tensor([[0.1711]])
 Epoch : 1 , Critic train loss: 0.09431567788124084 
 Epoch : 2 , Critic train loss: 0.09324267506599426 
 Epoch : 3 , Critic train loss: 0.06110667188962301 
 Epoch : 4 , Critic train loss: 0.03728598294158777 
 Epoch : 5 , Critic train loss: 0.02466679352025191 
 Epoch : 6 , Critic train loss: 0.018129166194962132 
 Epoch : 7 , Critic train loss: 0.014335107055330087 
 Epoch : 8 , Critic train loss: 0.011838365811485028 
 Epoch : 9 , Critic train loss: 0.010045669610674511 
Control value is [[0.32846534]]
target for action network is tensor([[0.5096]], grad_fn=<AddBackward0>)
 Epoch : 1 , Actor Train loss: 0.2596904933452606 

Control value is [[0.32846534]]
target for action network is tensor([[0.5096]], grad_fn=<AddBackward0>)
 Epoch : 2 , Actor Train loss: 0.2596904933452606 

Control value is [[0.32846534]]
target for action network is tensor([[0.5096]], grad_fn=<AddBackward0>)
 Epoch : 3 , Actor Train loss: 0.17312699556350708 

Control value is [[0.32846534]]
target for action network is tensor([[0.5096]], grad_fn=<AddBackward0>)
 Epoch : 4 , Actor Train loss: 0.10820437222719193 

Control value is [[0.32846534]]
target for action network is tensor([[0.5096]], grad_fn=<AddBackward0>)
 Epoch : 5 , Actor Train loss: 0.0735789731144905 

Control value is [[0.32846534]]
target for action network is tensor([[0.5096]], grad_fn=<AddBackward0>)
 Epoch : 6 , Actor Train loss: 0.05554491107662519 

Control value is [[0.32846534]]
target for action network is tensor([[0.5096]], grad_fn=<AddBackward0>)
 Epoch : 7 , Actor Train loss: 0.04503362920312654 

Control value is [[0.32846534]]
target for action network is tensor([[0.5096]], grad_fn=<AddBackward0>)
 Epoch : 8 , Actor Train loss: 0.038090515318548394 

Control value is [[0.32846534]]
target for action network is tensor([[0.5096]], grad_fn=<AddBackward0>)
 Epoch : 9 , Actor Train loss: 0.03308677874042323 

Hi,

This line looks suspicious: target_a = Variable(target_a, requires_grad = True)
Note that Variables are not a thing anymore, so you can remove it.
Also, it seems like both inputs to criterion(target_a, zeros_like) are actually leaf Tensors with no history? I think you break the graph earlier. In particular, reward_a is computed with estimation_control_numpy, which comes from a .data call that breaks the graph (note that you should not use .data anymore; you can replace it with .detach() here).
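
To see what "no history" means here: once a tensor goes through .data (or a numpy round trip), it comes back as a fresh leaf with no grad_fn, so nothing upstream can receive gradients. A minimal sketch (the small linear layer below just stands in for actorNN):

import torch

net = torch.nn.Linear(2, 1)   # stand-in for actorNN
x = torch.ones(1, 2)

out = net(x)
print(out.grad_fn)            # a backward node: still attached to the graph

# round trip through .data / numpy: the result is a new leaf tensor
detached = torch.from_numpy(out.data.cpu().numpy())
print(detached.grad_fn)       # None: no history, the graph is broken

# any loss built from `detached` cannot reach net's parameters;
# calling backward() on it would even raise "does not require grad",
# and wrapping it in a new requires_grad tensor only hides that error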

Hello, thank you very much for your time.

"This line looks suspicious: target_a = Variable(target_a, requires_grad = True)
Note that Variables are not a thing anymore, so you can remove it."

I deleted it, but it leads to the same results.

"I think you break the graph earlier. In particular, reward_a is computed with estimation_control_numpy, which comes from a .data call that breaks the graph (note that you should not use .data anymore; you can replace it with .detach() here)."

I think you are pointing at something very useful.
Where can I find the .data? I do not see it explicitly, so how can I resolve this problem? Sorry, but I cannot picture how to use .detach() here; can you suggest how?

Ideally, I want “estimation_control_numpy” to drive the formation of target_a and to backprop within the graph.
Thanks.

You use .data here:

def to_np(x):
    return x.data.cpu().numpy()

You can replace it with

def to_np(x):
    return x.detach().cpu().numpy()

Also note that in both cases, if you go through numpy, autograd will not be able to compute gradients. So no intermediate results should be computed with numpy arrays; only Tensors should be used.
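
Concretely for this code, that means the actor's reward has to be computed from estimation_control as a tensor rather than from its numpy copy. A possible tensor-only version of instant_reward (a sketch, assuming Q and R are converted to tensors once; the names Q_t, R_t and instant_reward_torch are only illustrative):

# differentiable stage cost: r = x^T Q x + R * u^2, kept entirely in torch
Q_t = torch.from_numpy(Q).float()
R_t = torch.tensor(float(R))

def instant_reward_torch(state_t, control_t):
    # state_t: (1, 2) tensor, control_t: (1, 1) tensor straight from actorNN
    rew1 = state_t @ Q_t @ state_t.t()
    rew2 = R_t * control_t * control_t
    return rew1 + rew2

Feeding estimation_control (the tensor output of actorNN) into this keeps the grad_fn chain intact, so loss_a can backpropagate into actorNN.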


Thank you, AlbanD, that worked.
I now do all the intermediate operations in torch and I can see the network learning.
Many thanks indeed.
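
For completeness, a minimal sketch of what the actor update can look like once all intermediate operations stay in torch; this is an illustration of the fix described above rather than the exact code used, and it reuses the instant_reward_torch sketch from the previous reply (x0_t is just the tensor version of x0):

x0_t = torch.from_numpy(x0).float()        # (1, 2) state tensor

for epoch in epoch_iter_a:
    estimation_control = actorNN(x0_t)     # keep the tensor; no .data / numpy copy

    # reward computed from the tensor, so it stays in the autograd graph
    reward_a = instant_reward_torch(x0_t, estimation_control)

    # next state predicted by the (frozen, pretrained) system model, still tensors
    x1_a = model(torch.cat((x0_t, estimation_control), dim=1))

    # target = reward + critic's cost-to-go; no Variable(..., requires_grad=True) needed
    target_a = reward_a + criticNN(x1_a)

    loss_a = criterion(target_a, torch.zeros_like(target_a))

    optimizer_a.zero_grad()
    loss_a.backward()    # gradients now flow back into actorNN's parameters
    optimizer_a.step()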
