Sorry, this was a false alarm. Thank you for pointing me towards doing proper profiling!
In the end, it turns out that the GPU is being used for both the forward and backward passes, as expected, and the piece that was using the CPU was my call to `max()`, because it returns a Python tuple:
next_model_states = torch.max(next_state_predictions, dim=1, keepdim=True)[0]
I think I should be able to replace this with `amax()`:
next_model_states = next_state_predictions.amax(dim=1, keepdim=True)
In case it’s helpful to anyone, here’s the full code that I used to test this:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import torch.profiler as profiler
class Linear_QNet(nn.Module):
    """Fully connected network whose hidden layers are re-fed the raw input.

    The network is ``linear1 -> ReLU -> linear2 -> ReLU -> linear3 -> ReLU
    -> linear7``. Before ``linear2`` and ``linear3`` the current activation
    is concatenated with the original input along dim 1, which is why those
    two layers take ``hidden_size + input_size`` input features.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of every hidden layer.
        output_size: Number of values produced per input row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        # These layers consume cat(activation, original input), hence the
        # wider in_features.
        self.linear2 = nn.Linear(hidden_size + input_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size + input_size, hidden_size)
        self.linear7 = nn.Linear(hidden_size, output_size)
        # In-place ReLU overwrites the freshly computed linear output
        # instead of allocating a new tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, original):
        """Run the network on a batch and return one output row per input row.

        Args:
            original: 2-D tensor of shape ``(batch, input_size)``; it is
                concatenated with hidden activations along dim 1.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # Label the whole forward pass so it appears as a single named range
        # in torch.profiler output.
        with profiler.record_function("FORWARD PASS"):
            x = self.linear1(original)
            x = self.relu(x)
            x = self.linear2(torch.cat((x, original), 1))
            x = self.relu(x)
            x = self.linear3(torch.cat((x, original), 1))
            x = self.relu(x)
            x = self.linear7(x)
            return x
if __name__ == '__main__':
    # Profiling harness: load a saved model and replay memory, run a single
    # training step under torch.profiler, and print the most expensive ops.
    #
    # NOTE(review): the indentation of this script was lost in the pasted
    # original; the nesting below is reconstructed. In particular, the exact
    # extent of the "MISC TRANSFORMS" and "BACKWARD" ranges should be
    # confirmed against the author's original file.

    # Make newly created tensors and modules default to the GPU.
    torch.set_default_device('cuda')

    # Allow-list deque for torch.load (presumably the checkpoint's 'memory'
    # entry is a deque -- confirm against the code that wrote state.pt).
    torch.serialization.add_safe_globals([deque])
    memory = deque()
    model = Linear_QNet(247, 247, 64)
    saved_model = torch.load('./state.pt')
    model.load_state_dict(saved_model['model_state'])
    memory.extend(saved_model['memory'])
    model.train()

    # Each memory entry is a 5-tuple; transpose into five parallel sequences.
    states, actions, rewards, next_states, not_dones = zip(*memory)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    states = torch.tensor(states, dtype=torch.float)
    next_states = torch.tensor(next_states, dtype=torch.float)
    not_dones = torch.tensor(not_dones, dtype=torch.bool)
    rewards = torch.tensor(rewards, dtype=torch.float)
    actions = torch.tensor(actions, dtype=torch.int64)

    # Warm up
    model(states)

    with profiler.profile(
        activities=[
            profiler.ProfilerActivity.CPU,
            profiler.ProfilerActivity.CUDA,
        ]
    ) as prof:
        pred = model(states)
        next_state_predictions = model(next_states)

        with profiler.record_function("MISC TRANSFORMS"):
            # torch.max with dim= returns (values, indices); [0] keeps the
            # per-row maximum values.
            next_model_states = torch.max(next_state_predictions, dim=1, keepdim=True)[0]
            # Scale by 0.9 (presumably the discount factor), mask with
            # not_dones, then add the reward.
            next_model_states.squeeze_(1).mul_(0.9)
            next_model_states.mul_(not_dones)
            next_model_states.add_(rewards)
            next_model_states_length = len(next_model_states)
            # The target equals pred everywhere except the column of the
            # action taken in each row, which is overwritten by scatter_.
            target = pred.clone()
            target.scatter_(
                1,
                actions.view(next_model_states_length, 1),
                next_model_states.unsqueeze_(0).view(next_model_states_length, 1),
            )

        optimizer.zero_grad()

        with profiler.record_function("LOSS FUNCTION"):
            loss = criterion(target, pred)

        with profiler.record_function("BACKWARD"):
            loss.backward()

        optimizer.step()

    # The tables are printed after the profiler context has exited.
    print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
    print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=10))
Result:
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
autograd::engine::evaluate_function: AddmmBackward0 0.00% 464.000us 2.91% 851.084ms 106.386ms 0.000us 0.00% 20.746s 2.593s 8
AddmmBackward0 0.00% 258.500us 0.05% 15.207ms 1.901ms 0.000us 0.00% 20.234s 2.529s 8
aten::mm 0.00% 1.429ms 0.05% 14.616ms 1.044ms 20.234s 71.92% 20.234s 1.445s 14
volta_sgemm_32x32_sliced1x4_nt 0.00% 0.000us 0.00% 0.000us 0.000us 11.169s 39.70% 11.169s 2.792s 4
volta_sgemm_64x64_nn 0.00% 0.000us 0.00% 0.000us 0.000us 8.014s 28.49% 8.014s 2.004s 4
FORWARD PASS 0.00% 0.000us 0.00% 0.000us 0.000us 6.488s 23.06% 6.488s 3.244s 2
FORWARD PASS 0.01% 1.608ms 0.84% 245.611ms 122.806ms 0.000us 0.00% 6.467s 3.234s 2
aten::linear 0.00% 96.100us 0.07% 19.888ms 2.486ms 0.000us 0.00% 3.568s 445.955ms 8
aten::addmm 0.00% 1.333ms 0.07% 19.515ms 2.439ms 3.545s 12.60% 3.568s 445.955ms 8
aten::cat 0.00% 424.300us 0.76% 223.331ms 55.833ms 2.712s 9.64% 2.712s 677.988ms 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 29.248s
Self CUDA time total: 28.133s
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
cudaDeviceSynchronize 70.63% 20.657s 70.63% 20.657s 20.657s 0.000us 0.00% 0.000us 0.000us 1
cudaLaunchKernel 3.25% 952.014ms 24.96% 7.300s 100.002ms 0.000us 0.00% 271.087ms 3.714ms 73
Runtime Triggered Module Loading 24.94% 7.294s 24.94% 7.294s 521.005ms 180.699ms 0.64% 180.699ms 12.907ms 14
MISC TRANSFORMS 0.00% 1.165ms 21.53% 6.297s 6.297s 0.000us 0.00% 149.373ms 149.373ms 1
aten::max 0.00% 210.700us 21.35% 6.246s 6.246s 25.323ms 0.09% 101.293ms 101.293ms 1
BACKWARD 0.16% 45.474ms 3.40% 994.365ms 994.365ms 0.000us 0.00% 0.864us 0.864us 1
autograd::engine::evaluate_function: AddmmBackward0 0.00% 464.000us 2.91% 851.084ms 106.386ms 0.000us 0.00% 20.746s 2.593s 8
aten::sum 0.00% 505.700us 2.86% 835.087ms 104.386ms 512.140ms 1.82% 512.140ms 64.018ms 8
cudaMalloc 0.95% 277.817ms 0.95% 277.817ms 15.434ms 0.000us 0.00% 0.000us 0.000us 18
FORWARD PASS 0.01% 1.608ms 0.84% 245.611ms 122.806ms 0.000us 0.00% 6.467s 3.234s 2
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 29.248s
Self CUDA time total: 28.133s
@ptrblck Thank you again for your help. Very much appreciated!