Backpropagation for custom loss

Hey guys,
For educational purposes, I’m trying to write a deep-learning actor-critic model for the simplest problem I could find: the multi-armed bandit.

For simplicity’s sake, I’m starting with a single machine that returns 0 or 1 with equal probability (I’ll make the problem more complex step by step once I understand the basics :wink:).
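With one fair 0/1 machine, the true value of the single state is the expected payout, 0.5. The update V ← V + α·δ with δ = r − V is an incremental running average of the rewards. The sketch below uses no network at all. It assumes a fixed step size `alpha` (a hypothetical parameter, not from the post) and a seeded random source:

```python
import random


def estimate_value(num_pulls=20000, alpha=0.001, seed=0):
    """Tabular value estimate for a single state with 0/1 rewards."""
    rng = random.Random(seed)
    value = 0.0
    for _ in range(num_pulls):
        reward = 1.0 if rng.random() < 0.5 else 0.0  # fair 0/1 machine
        delta = reward - value      # same delta as in the post (no next state)
        value += alpha * delta      # small step toward the observed reward
    return value


print(round(estimate_value(), 2))
```

With a constant step size, the estimate keeps jittering around 0.5. Using `alpha = 1/n` on the n-th pull would make it exactly the sample mean of the rewards.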

So I have a working model in which the critic correctly converges to a value of 0.5 for the single state (playing the game once). But I’m sure I’m doing several unnecessary steps in the training part:

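```python
import random

import torch
import torch.nn.functional as F


class Bandit:
    # Stand-in for the `bandit` object used in game() (not shown in the post):
    # a single machine that returns 0 or 1 with equal probability.
    def step(self, action):
        return 1.0 if random.random() < 0.5 else 0.0


bandit = Bandit()


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 64)
        self.l2 = torch.nn.Linear(64, 64)
        self.actor = torch.nn.Linear(64, 3)   # action logits (unused so far)
        self.critic = torch.nn.Linear(64, 1)  # state-value estimate V(s)
        self.memory = []
        self.GAMMA = 0.8

    def forward(self, obs):
        hl1 = self.l1(obs)
        hl1 = F.relu(hl1)
        hl2 = self.l2(hl1)
        hl2 = F.relu(hl2)
        #print('hl2',hl2)
        critic_out = self.critic(hl2)
        actor_out = self.actor(hl2)
        return actor_out, critic_out

    def get_pred_withgrad(self, obs):
        return self.forward(obs)

    def get_pred_nograd(self, obs):
        with torch.no_grad():
            return self.forward(obs)

    def game(self):
        action = 0
        reward = bandit.step(action)
        #memory.append
        self.train_step(action, reward)

    # Named train_step rather than train: nn.Module.train() switches between
    # training and evaluation mode, and overriding it breaks .train()/.eval().
    def train_step(self, action, reward):
        obs = torch.tensor([0], dtype=torch.float)
        _, bad_critic = self.forward(obs)
        bad_critic_value = bad_critic.item()
        print('Value estimation: ', bad_critic_value)
        _, better_critic = self.get_pred_nograd(obs) # this would be new_obs if we actually have different states
        delta = reward - bad_critic_value # + self.GAMMA * better_critic[0].item()  not included for now since we don't have a next state.
        print('Delta: ', delta)
        better_critic[0] += delta

        # size_average=False is deprecated; reduction='sum' is its replacement.
        loss_fn = torch.nn.MSELoss(reduction='sum')
        # Rebuilt on every call: harmless for plain SGD (no momentum state),
        # though an optimizer is normally created once in __init__.
        optimizer = torch.optim.SGD(self.parameters(), lr=0.0001)
        optimizer.zero_grad()

        loss = loss_fn(bad_critic, better_critic)
        loss.backward()
        optimizer.step()
```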

So basically, my question is this: the delta I calculate IS the loss (strictly, the summed MSE here equals delta squared). How can I use it directly to backpropagate? What I’m doing works, but it’s probably ridiculous :smiley:
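`better_critic` starts out equal to `bad_critic` (same weights, same input), so `better_critic[0] + delta` is just `reward`. That means the summed MSE in the training method is exactly `delta ** 2`. The sketch below builds that loss directly from a delta that still carries gradients. It assumes a cut-down, hypothetical `TinyCritic` module and a plain SGD optimizer created once. The step count and learning rate are illustrative guesses, not tuned values:

```python
import random

import torch
import torch.nn.functional as F


class TinyCritic(torch.nn.Module):
    # Hypothetical, cut-down stand-in for the critic path of Net in the post.
    def __init__(self):
        super().__init__()
        self.hidden = torch.nn.Linear(1, 64)
        self.value_head = torch.nn.Linear(64, 1)

    def forward(self, obs):
        return self.value_head(F.relu(self.hidden(obs)))


def train_critic(steps=4000, lr=0.0005, seed=0):
    torch.manual_seed(seed)
    rng = random.Random(seed)
    net = TinyCritic()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)  # created once, reused
    obs = torch.tensor([0.0])  # the single state
    for _ in range(steps):
        reward = 1.0 if rng.random() < 0.5 else 0.0  # fair 0/1 machine
        value = net(obs)[0]        # V(s), still attached to the graph
        delta = reward - value     # no next state, so no GAMMA term
        loss = delta.pow(2)        # squared delta == summed MSE against reward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        return net(obs).item()


print(round(train_critic(), 2))
```

Because `value` is not detached, `loss.backward()` sends gradients through the critic. This version does not need the `no_grad` second forward pass or the in-place edit of `better_critic`. `F.mse_loss(value, torch.tensor(reward))` would give the same loss here.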

By the way, I’m aware that the whole state machinery is unnecessary for now, but it will make more sense once I expand the problem…
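Once there are real next states, the commented-out `self.GAMMA` term on the `delta` line turns the target into the one-step TD target r + γ·V(s'). The sketch below assumes a generic `critic` callable and a `done` flag marking the last step of an episode (both hypothetical names). V(s') is computed under `no_grad`, so only V(s) is moved toward the target:

```python
import torch


def td_loss(critic, obs, reward, next_obs, done, gamma=0.8):
    """Squared one-step TD error; gradients flow only through V(obs)."""
    value = critic(obs)
    with torch.no_grad():
        next_value = critic(next_obs)
        # No bootstrapping past the last step of an episode.
        target = reward + gamma * next_value * (1.0 - float(done))
    delta = target - value
    return delta.pow(2).sum(), delta.detach()


critic = torch.nn.Linear(1, 1)
loss, delta = td_loss(critic, torch.tensor([0.0]), 1.0, torch.tensor([1.0]), done=False)
loss.backward()
print(loss.item(), delta.item())
```

The returned `delta` is detached, so it could also be reused as the advantage signal for the actor head without pushing gradients back into the critic.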

Thanks for your help!
PS: Not sure why the first block is not shown the same way as the rest…