I have the following DQN model:
class DQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    NOTE: the attribute names ``conv`` and ``fc`` and the position of every
    layer inside each ``nn.Sequential`` determine the parameter names in
    ``state_dict()``. Keep them unchanged so previously saved weights still load.
    """

    def __init__(self, input_shape, n_actions):
        """Build the network.

        Args:
            input_shape: shape of a single observation; ``input_shape[0]`` is
                used as the number of input channels of the first convolution.
            n_actions: number of outputs of the final linear layer.
        """
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # The flattened conv output size depends on the observation size,
        # so it is measured with a dummy forward pass.
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, shape):
        """Return the number of features the conv stack emits for one observation.

        A single all-zeros observation of the given ``shape`` is pushed through
        ``self.conv`` and the elements of the result are counted.
        """
        # A plain tensor is enough (the Variable wrapper is deprecated), and
        # no gradients are needed for a pure shape probe.
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, *shape))
        return int(probe.numel())

    def forward(self, x):
        """Compute one output row of ``n_actions`` values per batch element.

        Args:
            x: batch of observations; it is cast to float and divided by 256
                (presumably byte-valued pixels in 0..255 -- confirm against
                the environment wrappers).
        """
        # Keep the 256.0 divisor: the saved weights were trained with it.
        fx = x.float() / 256.0
        conv_out = self.conv(fx).view(fx.size(0), -1)
        return self.fc(conv_out)
I trained it on Pong and reached a mean reward above 18.0 over the last 100 games.
I then saved the trained weights:
# Persist only the learned parameters (the state dict), not the module object.
torch.save(obj=net.state_dict(), f='pong_model.pt')
But when I load the model and play some games, I get wrong results:
# Rebuild the network from the environment's observation/action spaces and
# restore the saved weights; the map_location callable keeps every tensor on CPU.
net = models.DQN(env.observation_space.shape, env.action_space.n)
net.load_state_dict(torch.load('pong_model.pt', map_location=lambda storage, loc: storage))
# Put the model in evaluation mode before inference, as PyTorch recommends
# after loading a state dict (affects dropout/batch-norm layers, if any).
net.eval()

# NOTE(review): `env` is created outside this snippet -- it presumably has to be
# wrapped with the same preprocessing as during training; confirm.
# NOTE(review): `agent` is defined outside this snippet -- verify that it picks
# the greedy action (no exploration) when playing.
for _ in range(20):
    state = env.reset()
    while True:
        env.render()
        action = agent(state, net)
        # Reward and info are ignored here; only the next state and the
        # episode-end flag are used.
        next_state, _, done, _ = env.step(action)
        if done:
            break
        state = next_state
It can’t win even once. What could be wrong?