Model() uses GPU but backward() doesn't

Sorry, this was a false alarm. Thank you for pointing me towards doing proper profiling! :slight_smile:

In the end, it turns out that the GPU is being used for both the forward and backward passes, as expected. The piece that was taking the CPU time was my call to max(). With dim=, torch.max returns a (values, indices) named tuple, and I only index out the values:

```python
next_model_states = torch.max(next_state_predictions, dim=1, keepdim=True)[0]
```

I think I should be able to replace this with amax():

```python
next_model_states = next_state_predictions.amax(dim=1, keepdim=True)
```
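Here is a minimal sketch of the difference. It uses a small CPU tensor with made-up values as a stand-in for next_state_predictions. With dim=, torch.max returns a torch.return_types.max named tuple whose values and indices fields are both tensors on the input's device. amax returns only the values.

```python
import torch

# Made-up stand-in for next_state_predictions (shape: batch x actions)
preds = torch.tensor([[1.0, 5.0, 3.0],
                      [4.0, 2.0, 6.0]])

# With dim=, torch.max returns a named tuple (values, indices); both are tensors
out = torch.max(preds, dim=1, keepdim=True)
values, indices = out.values, out.indices

# amax returns only the maximum values, with the same shape as out.values
amax_values = preds.amax(dim=1, keepdim=True)

print(values.tolist())       # [[5.0], [6.0]]
print(indices.tolist())      # [[1], [2]]
print(amax_values.tolist())  # [[5.0], [6.0]]
```

The two calls do differ in one way. According to the PyTorch docs, amax splits the gradient evenly between tied maxima, while max(dim=...) sends it to a single index. For a Q-value target this is usually irrelevant, but it can matter if gradients flow through that step.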

In case it’s helpful to anyone, here’s the full code that I used to test this:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import torch.profiler as profiler

class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # linear2 and linear3 take the previous hidden output concatenated with the original input
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size + input_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size + input_size, hidden_size)
        self.linear7 = nn.Linear(hidden_size, output_size)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, original):
        # Label the forward pass so it gets its own row in the profiler tables
        with profiler.record_function("FORWARD PASS"):
            x = self.linear1(original)
            x = self.relu(x)
            x = self.linear2(torch.cat((x, original), 1))
            x = self.relu(x)
            x = self.linear3(torch.cat((x, original), 1))
            x = self.relu(x)
            x = self.linear7(x)

        return x

if __name__ == '__main__':
    # New tensors (including the model parameters) are created on the GPU by default
    torch.set_default_device('cuda')

    # Mark deque as safe so torch.load's weights-only unpickler can restore the saved memory
    torch.serialization.add_safe_globals([deque])

    memory = deque()

    model = Linear_QNet(247, 247, 64)

    saved_model = torch.load('./state.pt')

    model.load_state_dict(saved_model['model_state'])
    memory.extend(saved_model['memory'])

    model.train()

    # Each memory entry is a (state, action, reward, next_state, not_done) tuple
    states, actions, rewards, next_states, not_dones = zip(*memory)

    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    states = torch.tensor(states, dtype=torch.float)
    next_states = torch.tensor(next_states, dtype=torch.float)
    not_dones = torch.tensor(not_dones, dtype=torch.bool)
    rewards = torch.tensor(rewards, dtype=torch.float)
    actions = torch.tensor(actions, dtype=torch.int64)

    # Warm up: one forward pass before profiling
    model(states)

    with profiler.profile(activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA]) as prof:
        pred = model(states)
        next_state_predictions = model(next_states)

        with profiler.record_function("MISC TRANSFORMS"):
            # Best predicted value per next state: [0] keeps the values and drops the indices
            next_model_states = torch.max(next_state_predictions, dim=1, keepdim=True)[0]
            next_model_states.squeeze_(1).mul_(0.9)
            next_model_states.mul_(not_dones)
            next_model_states.add_(rewards)
            next_model_states_length = len(next_model_states)
            # Note: target is cloned from pred and nothing here is detached, so
            # loss.backward() also runs back through the next_states forward pass
            target = pred.clone()
            # Write each row's target value into the column of the action that was taken
            target.scatter_(1, actions.view(next_model_states_length, 1), next_model_states.unsqueeze_(0).view(next_model_states_length, 1))
            optimizer.zero_grad()

        with profiler.record_function("LOSS FUNCTION"):
            loss = criterion(target, pred)

        with profiler.record_function("BACKWARD"):
            loss.backward()

    optimizer.step()

    # Print the same profile twice: sorted by total CUDA time, then by total CPU time
    print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))
    print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=10))
```
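The MISC TRANSFORMS block appears to build a Q-learning style target. For each row it takes the reward plus 0.9 times the best next-state value (zeroed when the episode is done) and writes that into the column of the action that was taken. Below is a compact sketch of that step on small made-up tensors. It assumes the common DQN convention of treating the target as a constant, using torch.no_grad() and detach(). The script above does not do this. Under that convention, only pred receives gradients, and only at the taken actions.

```python
import torch

# Made-up stand-ins for the script's tensors (batch of 3, 4 actions)
pred = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                     [0.5, 0.6, 0.7, 0.8],
                     [0.9, 1.0, 1.1, 1.2]], requires_grad=True)
next_state_predictions = torch.tensor([[1.0, 2.0, 0.0, 0.0],
                                       [0.0, 0.0, 3.0, 1.0],
                                       [4.0, 0.0, 0.0, 0.0]])
actions = torch.tensor([1, 3, 0])
rewards = torch.tensor([1.0, 0.0, -1.0])
not_dones = torch.tensor([True, False, True])
gamma = 0.9  # the discount factor hard-coded as 0.9 in the script

# reward + gamma * max over next-state values, zeroed where the episode ended
with torch.no_grad():
    next_values = next_state_predictions.amax(dim=1) * gamma * not_dones + rewards

# Start from a detached copy of pred and overwrite only the taken action per row
target = pred.detach().clone()
target.scatter_(1, actions.view(-1, 1), next_values.view(-1, 1))

loss = torch.nn.functional.mse_loss(pred, target)
loss.backward()
```

Because target equals pred everywhere except the scattered entries, the other entries contribute zero loss and zero gradient.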

Result (the same profile, sorted first by CUDA time and then by CPU time):

```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
    autograd::engine::evaluate_function: AddmmBackward0         0.00%     464.000us         2.91%     851.084ms     106.386ms       0.000us         0.00%       20.746s        2.593s             8  
                                         AddmmBackward0         0.00%     258.500us         0.05%      15.207ms       1.901ms       0.000us         0.00%       20.234s        2.529s             8  
                                               aten::mm         0.00%       1.429ms         0.05%      14.616ms       1.044ms       20.234s        71.92%       20.234s        1.445s            14  
                         volta_sgemm_32x32_sliced1x4_nt         0.00%       0.000us         0.00%       0.000us       0.000us       11.169s        39.70%       11.169s        2.792s             4  
                                   volta_sgemm_64x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us        8.014s        28.49%        8.014s        2.004s             4  
                                           FORWARD PASS         0.00%       0.000us         0.00%       0.000us       0.000us        6.488s        23.06%        6.488s        3.244s             2  
                                           FORWARD PASS         0.01%       1.608ms         0.84%     245.611ms     122.806ms       0.000us         0.00%        6.467s        3.234s             2  
                                           aten::linear         0.00%      96.100us         0.07%      19.888ms       2.486ms       0.000us         0.00%        3.568s     445.955ms             8  
                                            aten::addmm         0.00%       1.333ms         0.07%      19.515ms       2.439ms        3.545s        12.60%        3.568s     445.955ms             8  
                                              aten::cat         0.00%     424.300us         0.76%     223.331ms      55.833ms        2.712s         9.64%        2.712s     677.988ms             4  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 29.248s
Self CUDA time total: 28.133s
```

```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                  cudaDeviceSynchronize        70.63%       20.657s        70.63%       20.657s       20.657s       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel         3.25%     952.014ms        24.96%        7.300s     100.002ms       0.000us         0.00%     271.087ms       3.714ms            73  
                       Runtime Triggered Module Loading        24.94%        7.294s        24.94%        7.294s     521.005ms     180.699ms         0.64%     180.699ms      12.907ms            14  
                                        MISC TRANSFORMS         0.00%       1.165ms        21.53%        6.297s        6.297s       0.000us         0.00%     149.373ms     149.373ms             1  
                                              aten::max         0.00%     210.700us        21.35%        6.246s        6.246s      25.323ms         0.09%     101.293ms     101.293ms             1  
                                               BACKWARD         0.16%      45.474ms         3.40%     994.365ms     994.365ms       0.000us         0.00%       0.864us       0.864us             1  
    autograd::engine::evaluate_function: AddmmBackward0         0.00%     464.000us         2.91%     851.084ms     106.386ms       0.000us         0.00%       20.746s        2.593s             8  
                                              aten::sum         0.00%     505.700us         2.86%     835.087ms     104.386ms     512.140ms         1.82%     512.140ms      64.018ms             8  
                                             cudaMalloc         0.95%     277.817ms         0.95%     277.817ms      15.434ms       0.000us         0.00%       0.000us       0.000us            18  
                                           FORWARD PASS         0.01%       1.608ms         0.84%     245.611ms     122.806ms       0.000us         0.00%        6.467s        3.234s             2  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 29.248s
Self CUDA time total: 28.133s
```
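When reading tables like these, the gap between Self CPU and CPU total tells you a lot. For example, aten::max shows a Self CPU of 210.700us but a CPU total of 6.246s. Almost all of that time was spent in calls nested under the op rather than in max itself. The same table also lists cudaLaunchKernel and Runtime Triggered Module Loading among the largest CPU entries. Below is a minimal sketch of pulling these columns out of prof.key_averages() in code. It is CPU-only and uses a random tensor, so its numbers will not match the run above.

```python
import torch
from torch.profiler import profile, ProfilerActivity, record_function

x = torch.randn(256, 64)

# CPU-only profile so the sketch runs without a GPU
with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("MISC TRANSFORMS"):
        values = torch.max(x, dim=1, keepdim=True)[0]

# Map each event name to (self CPU us, total CPU us, call count)
rows = {
    evt.key: (evt.self_cpu_time_total, evt.cpu_time_total, evt.count)
    for evt in prof.key_averages()
}

for name in ("MISC TRANSFORMS", "aten::max"):
    self_us, total_us, calls = rows[name]
    # total - self = time spent in calls nested under this event
    print(f"{name}: self={self_us:.1f}us total={total_us:.1f}us "
          f"nested={total_us - self_us:.1f}us calls={calls}")
```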

@ptrblck Thank you again for your help. Very much appreciated! :slight_smile: