Hello! I’ve been following bentrevett’s RL repo and had an idea: give `class Agent(nn.Module)` a method that does the action sampling and produces the `log_prob` needed for reinforcement learning. But it broke everything. It only breaks when I use it during training (where I need gradients); everywhere else it doesn’t change anything.
def sample_action(self, state, stoch=True):
    self.__action_pred, value_pred = self.forward(state)
    self.__action_prob = F.softmax(self.__action_pred, dim=-1)
    dist = distributions.Categorical(self.__action_prob)
    if stoch:
        action = dist.sample()
    else:
        action = torch.argmax(self.__action_prob, dim=-1)
    return action, value_pred, dist
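For comparison, here is a minimal self-contained sketch of the same sampling helper. The `TinyActorCritic` module and its layer sizes are illustrative, not the model from the repo; the only deliberate change from my version is that results are kept in local variables instead of being stashed on `self`, so each call works with fresh tensors from the current forward pass:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions

# Illustrative actor-critic module (not the original repo's architecture).
class TinyActorCritic(nn.Module):
    def __init__(self, obs_dim=4, n_actions=2):
        super().__init__()
        self.actor = nn.Linear(obs_dim, n_actions)
        self.critic = nn.Linear(obs_dim, 1)

    def forward(self, state):
        return self.actor(state), self.critic(state)

    def sample_action(self, state, stoch=True):
        # Keep everything in locals: attributes cached on self can
        # outlive the forward pass that produced them.
        action_pred, value_pred = self.forward(state)
        action_prob = F.softmax(action_pred, dim=-1)
        dist = distributions.Categorical(action_prob)
        if stoch:
            action = dist.sample()
        else:
            action = torch.argmax(action_prob, dim=-1)
        return action, value_pred, dist

agent = TinyActorCritic()
state = torch.randn(1, 4)
action, value, dist = agent.sample_action(state)
log_prob = dist.log_prob(action)
# log_prob and value are still attached to the graph, so
# (log_prob * advantage).backward() etc. would propagate gradients.
print(log_prob.requires_grad, value.requires_grad)
```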
This is the function that ruins learning. What’s more, even if all the method does is return
    self.__action_pred, value_pred = self.forward(state)
learning still stops.
You can find an example of the problem here.
Does anyone know what is going on?