Reducing computational overhead when caching data for the backward pass

Hello, I am trying to optimize a custom variant of a recurrent layer. My initial implementation was written in Python with PyTorch, and I am now reimplementing it with the C++ frontend, hoping to gain some speed by wrapping the C++ implementation inside torch.autograd.Function. So far I have had no success.

I found that the overhead comes from preparing and packing intermediate results for the backward pass. Without these steps, my C++ code runs about 10% faster than my Python implementation. What are some best practices for avoiding this kind of overhead? For now I am working only on the CPU; a GPU implementation will follow later.

In my first attempt I used std::vector to keep track of the results, and at the end of the forward pass I used torch::stack or torch::cat to combine them into a single tensor. Since this involves a lot of copying, it is no surprise to me that the C++ version is slower than the Python one.

Next, I thought it would be a good idea to pre-allocate one big tensor to cache all my intermediate results. Below is a very basic layout that illustrates my attempt:

#include <torch/torch.h>
#include <vector>

using torch::indexing::Slice;

std::vector<torch::Tensor> forward_pass(torch::Tensor input, torch::Tensor weights, torch::Tensor state) {
    // Cache layout: {result slot, time step, input.size(0), state dims...}.
    // The time dimension (input.size(1)) comes second so that t indexes it below.
    auto intermediate_results = torch::empty({ 5, input.size(1), input.size(0), state.size(1), state.size(2), state.size(3) });

    // Iterate over each time step
    for (int64_t t = 0; t < input.size(1); ++t) {
        intermediate_results.index_put_({0, t, Slice(), Slice(), Slice(), Slice()}, /* some data tensor for the first result */);

        // ...
    }

    return { output, new_state, intermediate_results };
}

I was hoping to avoid copying and merging my tensors by writing directly into a pre-allocated tensor, but I guess I was a bit naive here. I am starting this thread to ask for tips and tricks and for your experiences with problems like this. Maybe it isn’t even a good idea to pack the results into a single tensor in the first place… In any case, I would be happy for any suggestions or ideas.
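
To make concrete what I mean by writing results directly into pre-allocated memory, here is a minimal sketch in plain C++ without libtorch. It only illustrates the idea and is not my actual layer: StepCache, step_into, run_forward and the toy computation (input times weight) are made-up placeholders. Each time step gets a pointer into its own slice of one contiguous buffer and writes its result there, so no temporary is created and nothing is copied afterwards. I assume the libtorch counterpart would be to take a view of the cache and write through it with an in-place or _out variant of an operation, but I have not verified whether that actually removes the overhead.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical cache: one contiguous buffer holding `steps` slices of `width` floats.
struct StepCache {
    std::size_t steps;
    std::size_t width;
    std::vector<float> data;

    StepCache(std::size_t steps_, std::size_t width_)
        : steps(steps_), width(width_), data(steps_ * width_) {}

    // Pointer to the first element of the slice for time step t (no copy involved).
    float* slice(std::size_t t) { return data.data() + t * width; }
    const float* slice(std::size_t t) const { return data.data() + t * width; }
};

// Placeholder for one time step: writes input[i] * weight straight into `out`.
void step_into(float* out, const std::vector<float>& input, float weight) {
    for (std::size_t i = 0; i < input.size(); ++i) {
        out[i] = input[i] * weight;
    }
}

// Runs all time steps, writing each result directly into its slice of the cache.
StepCache run_forward(const std::vector<std::vector<float>>& inputs, float weight) {
    const std::size_t width = inputs.empty() ? 0 : inputs.front().size();
    StepCache cache(inputs.size(), width);
    for (std::size_t t = 0; t < inputs.size(); ++t) {
        assert(inputs[t].size() == width);  // every step must have the same width
        step_into(cache.slice(t), inputs[t], weight);
    }
    return cache;
}
```

The point of this layout is that the step function never allocates: the caller decides where the result lives, and the cache is already in its final packed form when the loop ends.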