Hi!
I have a question about speeding up a model that broadcasts complicated pointwise functions over large tensors with many (~5) dimensions.
In the model I have functions like this one (there are 4 different functions of comparable complexity):
import math
import torch

@torch.jit.script
def I_k_phi(t, alpha, p1, z1, thet2, beta2):
    pi = math.pi
    coeff = torch.sqrt(pi / (alpha + p1))
    ea1 = torch.exp(-(4 * alpha * p1 * z1 ** 2 + thet2 ** 2) / (4 * (alpha + p1)))
    ea2 = torch.cos(beta2 + thet2 * (t - (p1 * z1) / (alpha + p1)))
    return coeff * ea1 * ea2
The tensors I am mapping over have shape (S, B, D, N1, N2), where S is O(1), B is O(1000), and N1 and N2 are O(100). The problem is that some datasets require D to be ~8, but the model uses more memory than my GPU has for D > 6. Profiling shows that memory becomes a problem only during the backward pass, and if I replace these complicated functions with much simpler (but incorrect) ones, the memory problems disappear entirely.
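To put rough numbers on it (taking S = 1, D = 8, and float32 as representative values for the sake of arithmetic), every full-size pointwise intermediate that autograd saves for the backward pass is another buffer of this size, so a handful of temporaries per function, across four functions, adds up quickly:

```python
# Back-of-envelope cost per saved intermediate; S = 1 and D = 8 are
# representative values chosen for illustration.
S, B, D, N1, N2 = 1, 1000, 8, 100, 100
elements = S * B * D * N1 * N2
bytes_float32 = elements * 4
print(f"{bytes_float32 / 1e9:.2f} GB per saved intermediate")  # 0.32 GB
```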
An idea I had to reduce the memory use is to implement the gradients of these functions by hand with a custom torch.autograd.Function. Obviously this will take a lot of work, but I am happy to do it if I can be reasonably confident it will help. Is it likely to help, or am I confused about how the autograd engine works?
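For concreteness, here is a minimal sketch of what I have in mind for I_k_phi (the class name is mine, and only the gradients w.r.t. t and beta2 are filled in for brevity). The idea is that backward recomputes the sine factor from two saved tensors, instead of autograd keeping every pointwise temporary from the forward pass alive:

```python
import math
import torch

class IkPhi(torch.autograd.Function):
    """Sketch of a hand-written backward for I_k_phi.

    Only grad_t and grad_beta2 are implemented; the gradients for
    alpha, p1, z1, and thet2 are omitted here for brevity.
    """

    @staticmethod
    def forward(ctx, t, alpha, p1, z1, thet2, beta2):
        pi = math.pi
        coeff = torch.sqrt(pi / (alpha + p1))
        ea1 = torch.exp(-(4 * alpha * p1 * z1 ** 2 + thet2 ** 2) / (4 * (alpha + p1)))
        arg = beta2 + thet2 * (t - (p1 * z1) / (alpha + p1))
        # Save only what backward needs: the product coeff * ea1 and the
        # cosine argument, rather than every intermediate above.
        ctx.save_for_backward(coeff * ea1, arg, thet2)
        return coeff * ea1 * torch.cos(arg)

    @staticmethod
    def backward(ctx, grad_out):
        ce, arg, thet2 = ctx.saved_tensors
        # d/darg [ce * cos(arg)] = -ce * sin(arg); chain rule for t, beta2.
        dsin = -ce * torch.sin(arg) * grad_out
        grad_t = dsin * thet2
        grad_beta2 = dsin
        return grad_t, None, None, None, None, grad_beta2
```

The memory saving comes from `save_for_backward` holding two full-size tensors here, where the autograd tape for the original expression retains one per elementary op.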
Any help with this or alternative ideas for solutions would be much appreciated.
Thanks in advance!