I have a script that takes about 20 minutes to run on my machine when I run a single seed at a time. However, when I run several (10) seeds at once, the code slows down massively and takes around 3 hours to complete. I would expect running several seeds at once to slow things down a bit, but not so drastically, especially given that each seed uses around 950 MB of GPU memory, the machine I am running on has 4 GPUs with 40 GB of memory each, and I split the jobs evenly across the GPUs (so a maximum of 3 seeds on one GPU). I've profiled the code and found that the part that becomes much slower is the `.to(device)` method.
Unfortunately, I cannot store all of the data on the GPU: this is an RL problem, so there is no static dataset and new data is collected online. Is there a reason why this becomes so much slower with more seeds running?
When profiling the code, the time spent in the `.to(device)` method goes from 12107 ms to 657858 ms.
The PyTorch version is 1.12.1. Here is a reproducible script that does essentially what the full script does, just cleaned up a bit:
import math
import numpy as np
import torch
import torch.nn as nn
class MLPResidualLayer(nn.Module):
    """Two-layer fully connected block with a skip connection.

    The input goes through ``fc1`` and ``fc2``, each followed by a ReLU, and
    the result is added to the untouched input. Both layers map ``dim`` to
    ``dim``, so the feature size is preserved.
    """

    def __init__(self, dim: int) -> None:
        """Create the two ``dim -> dim`` linear layers.

        Args:
            dim: Size of the input (and output) feature dimension.
        """
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + relu(fc2(relu(fc1(x))))``."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        # Skip connection: the ReLU is applied before the residual is added.
        return x + hidden
class Network(nn.Module):
    """Shared MLP trunk followed by an ensemble of linear output heads.

    The trunk is ``Linear -> MLPResidualLayer -> LayerNorm``. Its output is
    fed to every head of a ``VectorizedLinear`` layer, and the per-head
    results are returned with the head axis moved to position 1.
    """

    def __init__(self, state_dim: int, hidden_dim: int, num_actions: int, num_heads: int) -> None:
        """Build the trunk and the vectorised heads.

        Args:
            state_dim: Number of input features.
            hidden_dim: Width of the trunk.
            num_actions: Number of outputs produced by each head.
            num_heads: Number of ensemble heads.
        """
        super().__init__()
        self.input_layer = nn.Linear(state_dim, hidden_dim)
        self.resnet = MLPResidualLayer(hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.output_heads = VectorizedLinear(hidden_dim, num_actions, num_heads)
        self.num_heads = num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the trunk and all heads.

        Args:
            x: Input batch; assumed to be ``(batch, state_dim)`` -- TODO
                confirm against callers.

        Returns:
            The head outputs with axes 0 and 1 swapped, i.e.
            ``(batch, num_heads, num_actions)`` for a 2-D input.
        """
        # Call submodules through ``__call__`` rather than ``.forward`` so
        # that any registered forward hooks are honoured.
        features = self.layer_norm(self.resnet(self.input_layer(x)))
        # Give every head the same features. ``expand`` creates a broadcast
        # view instead of materialising ``num_heads`` copies like ``repeat``.
        per_head = features.unsqueeze(dim=0).expand(self.num_heads, -1, -1)
        return self.output_heads(per_head).transpose(0, 1)
class VectorizedLinear(nn.Module):
    """An ensemble of independent linear layers evaluated in one batched matmul.

    Holds one ``(in_features, out_features)`` weight matrix and one bias row
    per ensemble member, stacked along the leading dimension.
    """

    def __init__(self, in_features: int, out_features: int, ensemble_size: int) -> None:
        """Allocate and initialise the stacked weights and biases.

        Args:
            in_features: Input feature size of every member.
            out_features: Output feature size of every member.
            ensemble_size: Number of ensemble members.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw weights and biases from ``U(-1/sqrt(in), 1/sqrt(in))``."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        # ``nn.init.uniform_`` fills in place without recording autograd
        # history, which avoids poking at ``.data`` directly.
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every ensemble member to its slice of ``x``."""
        # input: [ensemble_size, batch_size, input_size]
        # weight: [ensemble_size, input_size, out_size]
        # out: [ensemble_size, batch_size, out_size]
        return x @ self.weight + self.bias
class ReplayBuffer:
    """Fixed-size transition store, pre-filled with random data.

    All fields are allocated up front as default-device tensors of length
    ``capacity`` and filled with random values; ``sample`` draws a uniformly
    random mini-batch from them.
    """

    def __init__(self, capacity, state_dim, num_heads, action_dim, batch_size=128):
        """Allocate the buffer and fill it with random transitions.

        Args:
            capacity: Number of stored transitions.
            state_dim: Feature size of each state.
            num_heads: Number of action entries stored per transition.
            action_dim: Exclusive upper bound for the random action indices.
            batch_size: Number of transitions returned by ``sample``.
        """
        self.capacity = capacity
        self.batch_size = batch_size
        self.states = torch.randn(size=(capacity, state_dim), dtype=torch.float)
        self.actions = torch.randint(low=0, high=action_dim, size=(capacity, num_heads), dtype=torch.long)
        self.rewards = torch.randn(size=(capacity, 1), dtype=torch.float)
        self.next_states = torch.randn(size=(capacity, state_dim), dtype=torch.float)
        self.dones = torch.randint(low=0, high=2, size=(capacity, 1), dtype=torch.long)
        self.state_dim = state_dim
        self.num_heads = num_heads

    def sample(self):
        """Return a random mini-batch.

        Indices are drawn with replacement from NumPy's global RNG.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)``, each with
            ``batch_size`` rows.
        """
        # when buffer large the probability of sampling a transition more than once -> 0
        idx = np.random.randint(low=0, high=self.capacity, size=self.batch_size, dtype=np.int64)
        # Convert the index once instead of letting each of the five lookups
        # convert the NumPy array again; int64 is the dtype torch indexes with.
        idx = torch.from_numpy(idx)
        return self.states[idx], self.actions[idx], self.rewards[idx], self.next_states[idx], self.dones[idx]

    def __len__(self):
        """Return the buffer capacity (the buffer is always full)."""
        return self.capacity
# --- Reproduction script: repeated Q-learning-style updates on random data ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
state_dim = 10
action_dim = 3
num_heads = 6
batch_size = 256

net = Network(state_dim, 512, action_dim, num_heads).to(device)
optimiser = torch.optim.Adam(net.parameters())
loss_fn = nn.HuberLoss()
buffer = ReplayBuffer(100000, state_dim, num_heads, action_dim, batch_size)

# NOTE(review): the batch is sampled from CPU tensors and copied to `device`
# on every update. When several seeds run as separate processes they all
# compete for CPU cores for this sampling and copying; consider capping the
# per-process CPU thread count -- confirm by profiling.
for update in range(100000):
    if (update + 1) % 1000 == 0:
        print(f"Update {update + 1}")

    # Sample on the CPU, then move every field of the batch to the device.
    batch = buffer.sample()
    states, actions, rewards, next_states, dones = (field.to(device) for field in batch)

    # Value of the stored action for each head, averaged over heads.
    # Modules are called directly (not via ``.forward``) so hooks still run.
    vals = net(states).gather(2, actions.unsqueeze(dim=-1)).squeeze(dim=-1).mean(dim=1)

    # Bootstrapped target; no gradient flows through it.
    with torch.no_grad():
        targets = net(next_states).max(dim=-1)[0].mean(dim=-1, keepdim=True)
        targets = rewards + 0.99 * (1 - dones) * targets

    loss = loss_fn(vals, targets.flatten())
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()