Backpropagation for custom loss

Hey guys,
for educational purposes I’m trying to write a deep learning actor-critic model for the simplest problem I could find: the multi-armed bandit.

For simplicity’s sake, I’m starting out with one machine that returns 0 or 1 with equal probability (and I will slowly make the problem more complex once I understand the stuff).

So I have a model that’s working where the critic correctly converges to a value of 0.5 for the state of playing the game once. But I’m sure I’m doing several unnecessary steps in the training part:

class Net(torch.nn.Module):
    """Actor-critic network for a one-state multi-armed bandit.

    A shared two-layer MLP trunk feeds two heads: an actor head with one
    output per candidate action and a scalar critic head estimating the
    state value.
    """

    def __init__(self):
        # NOTE(review): the paste had `def init` / `self.init()` — markdown ate
        # the dunder underscores. nn.Module.__init__ must run before any
        # submodule assignment, otherwise the Linear layers can't register.
        super().__init__()
        self.l1 = torch.nn.Linear(1, 64)      # trunk: 1 input feature -> 64
        self.l2 = torch.nn.Linear(64, 64)     # trunk: 64 -> 64
        self.actor = torch.nn.Linear(64, 3)   # 3 outputs — presumably action logits; only arm 0 is used below
        self.critic = torch.nn.Linear(64, 1)  # scalar state-value estimate
        self.memory = []                      # experience buffer; unused in this snippet
        self.GAMMA = 0.8                      # discount factor; unused — single-step problem (see train())

def forward(self, obs):
    """Run the shared trunk and both heads.

    Args:
        obs: float tensor with a trailing dimension of 1 (the single-state
            observation) — presumably shape (1,) here; TODO confirm.

    Returns:
        (actor_out, critic_out): raw actor outputs and the scalar value
        estimate, both produced from the same ReLU trunk.
    """
    # NOTE(review): the paste had a stray `` `````` `` markdown residue glued to
    # this def line (a syntax error) and a commented-out debug print; both removed.
    h = F.relu(self.l1(obs))
    h = F.relu(self.l2(h))
    return self.actor(h), self.critic(h)

# NOTE(review): these two orphaned `return self.forward(obs)` lines look like the
# bodies of two helper methods (likely get_pred and get_pred_nograd — the latter
# is called in train() below) whose `def` lines were lost to the forum's markdown
# formatting. Restore the headers; the nograd variant should wrap the call in a
# `with torch.no_grad():` block so the critic target carries no gradient.
return self.forward(obs)

return self.forward(obs)

def game(self):
    """Play a single round: pull arm 0 and immediately learn from the payout."""
    arm = 0  # fixed choice for now — the actor head is not consulted yet
    payout = bandit.step(arm)
    self.train(arm, payout)

def train(self, action, reward):
    # Constant one-element observation: the bandit has a single state.
    obs = torch.tensor([0], dtype= torch.float)
    _, better_critic = self.get_pred_nograd(obs) # this would be new_obs if we actually have different states
    # NOTE(review): `bad_critic_value` is not defined anywhere in this snippet —
    # presumably the critic output of an earlier gradient-tracked forward pass;
    # as posted, this line raises NameError. Verify against the full code.
    delta = reward  - bad_critic_value # + self.GAMMA * better_critic.numpy()[0]  not included for now since we don't have a next state.
    print('Delta: ', delta)
    # Shift the no-grad prediction toward the observed reward to form the target.
    better_critic[0] += delta

    # NOTE(review): `size_average` is deprecated in modern PyTorch — use
    # MSELoss(reduction='sum'). Also, rebuilding the loss and optimizer on every
    # call works but is wasteful; construct them once in __init__. The method
    # appears to continue past this excerpt (no backward()/step() visible here).
    loss_fn = torch.nn.MSELoss(size_average=False)
    optimizer = torch.optim.SGD(self.parameters(), lr=0.0001)