StyleGAN2 TorchScript model in C++ runs half as fast as in Python

I’ve successfully converted the StyleGAN2 model from https://github.com/NVlabs/stylegan2-ada-pytorch to TorchScript by swapping in the reference implementations of bias_act and upfirdn2d in place of the custom CUDA ops. However, I’ve found that the model runs about 2x faster in Python than in C++.
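For reference, producing stylegan.pt looks roughly like this (a sketch rather than my exact script: the checkpoint filename is a placeholder, legacy ships with the repo, and I’m assuming the ops have already been patched to their reference implementations):

import torch

import legacy  # ships with stylegan2-ada-pytorch; the repo must be on PYTHONPATH

# "network-snapshot.pkl" is a placeholder; use your own checkpoint.
with open("network-snapshot.pkl", "rb") as f:
    G = legacy.load_network_pkl(f)["G_ema"].eval().cuda()

z = torch.randn(1, 512, device="cuda")

# Trace an unconditional forward pass (class labels c=None) and
# serialize it so torch::jit::load can read it from C++.
traced = torch.jit.trace(lambda z: G(z, None), z)
torch.jit.save(traced, "stylegan.pt")

Here is my benchmarking code for both of them: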

C++:

#include <torch/script.h>
#include <torch/torch.h>

#include <ctime>
#include <iostream>
#include <vector>

int main() {
    torch::NoGradGuard guard;  // inference only, no autograd bookkeeping
    torch::jit::script::Module model;
    torch::TensorOptions opts(torch::kCUDA);
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        model = torch::jit::load("stylegan.pt");
        model.eval();
        model.to(torch::kCUDA);

        std::vector<torch::jit::IValue> inputs;
        inputs.emplace_back(torch::rand({1, 512}, opts));

        // Warm-up: 20 untimed iterations.
        double out = 0;
        for (int i = 0; i < 20; ++i) {
            out += model.forward(inputs).toTensor().mean().item<double>();
        }

        // Timed loop. Note: on Windows clock() measures wall time,
        // but on POSIX it measures CPU time.
        auto start = clock();
        for (int i = 0; i < 200; ++i) {
            out += model.forward(inputs).toTensor().mean().item<double>();
        }
        std::cout << float(clock() - start) / CLOCKS_PER_SEC << "\n";
        std::cout << out << "\n";
    } catch (const c10::Error& e) {
        std::cerr << e.what();
        return -1;
    }
    return 0;
}

Python:

import time
import torch

DEVICE = torch.device("cuda")

def main():
    model = torch.jit.load("stylegan.pt")
    model.eval()
    model.to(DEVICE)

    inputs = [torch.rand((1, 512), device=DEVICE)]

    # Warm-up: 20 untimed iterations.
    # Note: unlike the C++ version, this runs without a no-grad guard;
    # wrap the loops in torch.no_grad() for a strictly equivalent comparison.
    out = 0
    for i in range(20):
        out += model.forward(*inputs).mean().item()

    # Timed loop.
    start = time.time()
    for i in range(200):
        out += model.forward(*inputs).mean().item()
    print(time.time() - start)
    print(out)

if __name__ == '__main__':
    main()

I think that these benchmarks should be equivalent, but my Python version takes only 22 seconds, while my C++ version takes 45 seconds. I’m running this on Windows; here’s my CMake output:

-- The C compiler identification is MSVC 19.29.30140.0
-- The CXX compiler identification is MSVC 19.29.30140.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: C:/Program Files (x86)/Microsoft Visual Studio/2019/BuildTools/VC/Tools/MSVC/14.29.30133/bin/Hostx64/x64/cl.exe - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: C:/Program Files (x86)/Microsoft Visual Studio/2019/BuildTools/VC/Tools/MSVC/14.29.30133/bin/Hostx64/x64/cl.exe - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
variable name is C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.5/lib
-- Looking for pthread.h
-- Looking for pthread.h - not found
-- Found Threads: TRUE  
-- Found CUDA: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.5 (found version "11.5") 
-- Caffe2: CUDA detected: 11.5
-- Caffe2: CUDA nvcc is: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.5/bin/nvcc.exe
-- Caffe2: CUDA toolkit directory: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.5
-- Caffe2: Header version is: 11.5
-- Found CUDNN: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.5/lib  
-- Found cuDNN: v8.3.1  (include: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.5/include, library: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.5/lib)
-- C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.5/lib/x64/nvrtc.lib shorthash is dd482e34
-- Autodetected CUDA architecture(s):  6.1
-- Added CUDA NVCC flags for: -gencode;arch=compute_61,code=sm_61
-- Found Torch: C:/libtorch/lib/torch.lib  
-- Configuring done
WARNING: Target "torchscriptspeedtest" requests linking to directory "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.5/lib".  Targets may link only to libraries.  CMake is dropping the item.
-- Generating done

Is there any reason why I should expect the C++ version to be less performant? As far as I can tell, both builds have access to the same features.

CUDA operations are executed asynchronously, so you would need to synchronize the device before starting and stopping the timers.
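For example, reusing model and inputs from your script (in C++, torch::cuda::synchronize() plays the same role):

import time
import torch

torch.cuda.synchronize()      # drain any pending work before starting the clock
start = time.time()
for _ in range(200):
    out = model(*inputs)      # forward() only queues kernels on the GPU
torch.cuda.synchronize()      # wait for the queued kernels to finish
print(time.time() - start)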

This would be an issue if it weren’t for the .item() calls, which force the data to the CPU and implicitly synchronize the CUDA device.
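A quick way to see the implicit synchronization (matrix size is arbitrary):

import time
import torch

x = torch.randn(4096, 4096, device="cuda")
torch.cuda.synchronize()

start = time.time()
y = x @ x                      # launched asynchronously, returns almost immediately
launch = time.time() - start

v = y.mean().item()            # must wait for the matmul before copying to the CPU
total = time.time() - start    # now includes the kernel execution time
print(launch, total)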

Hello, how did you convert the model to stylegan.pt? Why does my C++ inference fail after converting with torch.jit.trace?