I am getting a RuntimeError. I am fairly certain it is not the batch size (after looking at other posts), since the algorithm runs for quite a few iterations before crashing. Perhaps there is a memory leak somewhere? My suspicion is that it happens when I transfer data to the GPU in the compute_td_loss function. I sincerely apologize for posting this entire block of code -
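To narrow it down, this is the kind of check I want to run alongside training; a minimal sketch (torch.cuda.memory_allocated and torch.cuda.max_memory_allocated are standard PyTorch calls, but the helper name and logging interval here are my own):

def log_gpu_memory(step, interval=500):
    # Report bytes currently held by tensors on the active CUDA device; a steady
    # climb across episodes would point to a leak rather than one oversized batch.
    if torch.cuda.is_available() and step % interval == 0:
        alloc = torch.cuda.memory_allocated() / 2**20
        peak = torch.cuda.max_memory_allocated() / 2**20
        print(f"step {step}: allocated {alloc:.1f} MiB, peak {peak:.1f} MiB")

Calling this with frame_index inside the while loop below should show whether GPU memory actually grows over time.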
# Here we import all libraries
import numpy as np
#pip install gym[atari,accept-rom-license]
import gym
import matplotlib.pyplot as plt
import os
import torch
import random
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from collections import deque
import torchvision
import torch.nn.functional as F
import sys
env = gym.make("ALE/Pong-v5")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Hyperparameters
episodes = 20000*10
eps = 1.0
learning_rate = 0.001
tot_rewards = []
tot_loss = []
decay_val = 0.0001
mem_size = 1000000
batch_size = 300
gamma = 0.99
update_target = 100
max_steps = 200
PATH = "./saved_models/pong"
class NeuralNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super(NeuralNetwork, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.conv1 = nn.Conv2d(state_size, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(29008, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, action_size)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
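(For reference, the 29008 in self.fc1 is just the flattened conv output for a 3x210x160 Pong frame; my own shape arithmetic:)

# conv1 (5x5, no padding): 210x160 -> 206x156; pool (2x2): -> 103x78
# conv2 (5x5, no padding): 103x78  -> 99x74;   pool (2x2): -> 49x37
# flatten: 16 channels * 49 * 37 = 29008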
model = NeuralNetwork(env.observation_space.shape[2], env.action_space.n).to(device)
target = NeuralNetwork(env.observation_space.shape[2], env.action_space.n).to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
replay_buffer = deque(maxlen=mem_size)
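A back-of-envelope number on the buffer, in case it is relevant (my own arithmetic, based on the shapes I store below):

# Each stored float32 state is 3 * 210 * 160 * 4 bytes ~ 0.4 MB, and each
# transition also holds next_state as a uint8 numpy array (~0.1 MB), so
# ~10,000 transitions already occupy roughly 5 GB of host RAM, and a full
# mem_size = 1,000,000 buffer would be on the order of 500 GB.
bytes_per_transition = 3 * 210 * 160 * 4 + 3 * 210 * 160  # ~ 0.5 MB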
# # Testing code
# state = torch.tensor(env.reset(), dtype=torch.float32).unsqueeze(0)
# state = state.reshape(1, 3, 210, 160)
# print("state = ", state.shape)
# out = model(state)
# print("out = ", out)
def compute_td_loss(batch_size):
    state, next_state, reward, done, action = zip(*random.sample(replay_buffer, batch_size))
    # state = torch.stack(list(state), dim=0).reshape(batch_size, -1)
    # print("Shape of state = ", torch.stack(list(state), dim=0).squeeze(1).shape)
    state = torch.stack(list(state), dim=0).squeeze(1)
    state = state.reshape(batch_size, 3, 210, 160).to(device)
    # print("next state shape ", torch.from_numpy(np.array(next_state)).reshape(batch_size, 3, 210, 160).shape)
    # next_state = torch.from_numpy(np.array(next_state).reshape(batch_size, -1)).type(torch.float32)
    next_state = torch.from_numpy(np.array(next_state)).reshape(batch_size, 3, 210, 160).type(torch.float32).to(device)
    reward = torch.from_numpy(np.array(reward)).to(device)
    done = torch.from_numpy(np.array(done)).long().to(device)
    action = torch.from_numpy(np.array(action)).type(torch.int64).to(device)
    q_values = model(state)
    next_q_values = target(next_state)
    q_vals = q_values.gather(dim=-1, index=action.reshape(-1, 1))
    max_next_q_values = torch.max(next_q_values, -1)[0].detach()
    loss = ((reward + gamma * max_next_q_values * (1 - done) - q_vals.squeeze()) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss
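One variant I am considering, in case the target pass is keeping graphs alive: run the target network under torch.no_grad() and return a plain Python float instead of the loss tensor. A minimal sketch (torch.no_grad and .item() are standard PyTorch; the _v2 name is my own, and this is otherwise the same function as above):

def compute_td_loss_v2(batch_size):
    state, next_state, reward, done, action = zip(*random.sample(replay_buffer, batch_size))
    state = torch.stack(list(state), dim=0).squeeze(1).reshape(batch_size, 3, 210, 160).to(device)
    reward = torch.from_numpy(np.array(reward)).float().to(device)
    done = torch.from_numpy(np.array(done)).long().to(device)
    action = torch.from_numpy(np.array(action)).long().to(device)
    with torch.no_grad():
        # No autograd graph is built for the target network's forward pass.
        next_state = torch.from_numpy(np.array(next_state)).reshape(batch_size, 3, 210, 160).float().to(device)
        max_next_q_values = target(next_state).max(-1)[0]
    q_vals = model(state).gather(dim=-1, index=action.reshape(-1, 1)).squeeze()
    loss = ((reward + gamma * max_next_q_values * (1 - done) - q_vals) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()  # a float, so the caller cannot accidentally keep the graph alive

(With this version the training loop would accumulate eps_loss += loss directly instead of calling loss.cpu().detach().numpy().)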
if os.path.exists(PATH):
    model.load_state_dict(torch.load(PATH))
else:
    frame_index = 0
    for i in range(episodes):
        state = torch.tensor(env.reset(), dtype=torch.float32).unsqueeze(0)
        state = state.reshape(1, 3, 210, 160)
        done = False
        steps = 0
        eps_rew = 0
        eps_loss = 0
        while not done and steps < max_steps:
            print("frame_index = ", frame_index, "episode = ", i)
            if np.random.uniform(0, 1) < eps:
                action = env.action_space.sample()
            else:
                # action = env.action_space.sample()
                action = torch.argmax(model(state.to(device))).cpu().detach().numpy()
            next_state, reward, done, info = env.step(action)
            replay_buffer.append((state, next_state, reward, done, action))
            if len(replay_buffer) > batch_size and steps % 4 == 0:
                loss = compute_td_loss(batch_size)
                eps_loss += loss.cpu().detach().numpy()
            eps = eps / (1 + decay_val)
            eps_rew += reward
            if steps % 50 == 0:
                target.load_state_dict(model.state_dict())
            if done:
                tot_rewards.append(eps_rew)
                break
            state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
            state = state.reshape(1, 3, 210, 160)
            steps += 1
            frame_index += 1
        tot_rewards.append(eps_rew)
        tot_loss.append(eps_loss)
        if (i % 100) == 0:
            plt.scatter(np.arange(len(tot_rewards)), tot_rewards)
            plt.show()
    torch.save(model.state_dict(), PATH)
Please let me know if you have any questions. Thank you very much.