Speeding up autograd for very complex pointwise tensor operations

Hi!

I have a question about speeding up a model I have written that involves broadcasting complicated pointwise functions over large tensors with several (~5) dimensions.

In the model I have functions like the one below (there are 4 different functions of similar complexity):

import math
import torch

@torch.jit.script
def I_k_phi(t, alpha, p1, z1, thet2, beta2):
    pi = math.pi
    coeff = torch.sqrt(pi / (alpha + p1))
    ea1 = torch.exp(-(4 * alpha * p1 * z1 ** 2 + thet2 ** 2) / (4 * (alpha + p1)))
    ea2 = torch.cos(beta2 + thet2 * (t - (p1 * z1) / (alpha + p1)))
    return coeff * ea1 * ea2

The tensors I am mapping over are of shape (S, B, D, N1, N2), where S is O(1), B is O(1000), and N1 and N2 are O(100). The problem is that for some datasets I need D to be ~8, but the model uses more memory than my GPU can handle for D > 6. After some profiling, I found that memory only becomes a problem during the backward pass. I have also tried replacing these complex functions with much simpler (but incorrect) ones, and the memory problems disappear completely.
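For a rough sense of scale (these exact numbers are made up, but representative): with float32 and S = 1, B = 1000, D = 8, N1 = N2 = 100, a single tensor of shape (S, B, D, N1, N2) is 1000 * 8 * 100 * 100 * 4 bytes ≈ 0.3 GB. As far as I understand, autograd keeps several intermediates of that full broadcast shape (e.g. the inputs/outputs of the exp and cos calls above) alive for the backward pass of each of the 4 functions, so the saved tensors alone quickly run to several GB.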

An idea I had to reduce the memory use is to implement the gradients of these functions by hand using a custom torch.autograd.Function. Obviously this will take a lot of work, but I am happy to do it if I can be reasonably confident it will help. Is it likely to help, or am I confused about how the autograd engine works?
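
For concreteness, here is a rough sketch of the pattern I have in mind, using a simplified stand-in function f(t, a, b) = exp(-a * t**2) * cos(b * t) rather than the real I_k_phi (just to keep the hand-derived gradients short). The point is that backward saves only the inputs and recomputes the intermediates, instead of autograd keeping every intermediate tensor from the forward pass alive:

import torch

class PointwiseF(torch.autograd.Function):
    # Toy stand-in for I_k_phi: f(t, a, b) = exp(-a * t**2) * cos(b * t)

    @staticmethod
    def forward(ctx, t, a, b):
        # Save only the inputs; intermediates are recomputed in backward.
        ctx.save_for_backward(t, a, b)
        return torch.exp(-a * t ** 2) * torch.cos(b * t)

    @staticmethod
    def backward(ctx, grad_out):
        t, a, b = ctx.saved_tensors
        ea = torch.exp(-a * t ** 2)
        c = torch.cos(b * t)
        s = torch.sin(b * t)
        # Hand-derived partial derivatives of f. (All inputs have the same
        # shape here; with broadcasting, each gradient would have to be
        # summed back down to its input's shape.)
        df_dt = ea * (-2 * a * t * c - b * s)
        df_da = -t ** 2 * ea * c
        df_db = -t * ea * s
        return grad_out * df_dt, grad_out * df_da, grad_out * df_db

# Numerical check of the hand-written gradients on small double-precision inputs.
t = torch.randn(5, dtype=torch.double, requires_grad=True)
a = torch.rand(5, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(PointwiseF.apply, (t, a, b)))

The recomputation in backward obviously costs extra FLOPs, but for pointwise ops like these that seems like a trade I would happily make if it actually reduces the peak memory.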

Any help with this, or alternative ideas for a solution, would be much appreciated.

Thanks in advance!