PyTorch has supported the matrix exponential (mexp) since version 1.7 via torch.matrix_exp, but it does not seem to offer the action of the matrix exponential on a few specific vectors. SciPy, for instance, implements this feature as scipy.sparse.linalg.expm_multiply. In practice, computing the whole mexp may be unnecessarily costly when only its action on a few vectors is needed. This led me to implement the algorithm proposed in “Computing the Action of the Matrix Exponential” by Al-Mohy and Higham. In my first attempt, I wrote it as a plain Python function and relied on PyTorch's built-in autograd support for the forward and backward passes. Unfortunately, I suspect that the memory footprint of my implementation is not ideal.
My main question is: what can be done to reduce the memory footprint? I am tempted to reimplement it as a custom torch.autograd.Function, either in Python or in C++, but it is unclear whether that would save a significant amount of memory. The algorithm uses the scaling part of the “scaling-and-squaring” method together with a truncated Taylor series, so the action of mexp is computed iteratively. In each iteration the vectors are multiplied by the scaled matrix and the products are accumulated into a running sum (see the code below). Forward-mode automatic differentiation (AD) of this computation is reasonably efficient because the derivatives can be propagated greedily alongside the iteration; unfortunately, this is not the case in reverse mode. As far as I can tell, reverse mode needs to store every intermediate product of the scaled matrix with the vectors, although I might be missing something here.
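For reference, the iteration described above can be written out as follows. This is only a sketch of what the code in this post computes, using the column convention (the vectors are the columns of B, whereas the code stores them as rows); the symbols T_m, F_i and B_j are notation introduced here for illustration and do not come from the original code.

```latex
e^{tA} B = \left(e^{tA/s}\right)^{s} B \approx \left(T_m\!\left(\tfrac{t}{s}A\right)\right)^{s} B,
\qquad
T_m(X) = \sum_{j=0}^{m} \frac{X^{j}}{j!}

% One outer iteration i = 1, ..., s, starting from F_0 = B:
B_0 = F_{i-1}, \qquad
B_j = \frac{t}{s\,j}\, A\, B_{j-1} \quad (j = 1, \dots, m), \qquad
F_i = \sum_{j=0}^{m} B_j
```

Each B_j is the input of the next matrix product, which is why reverse mode keeps all s·m of them alive unless they are recomputed.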
Any tips on how to proceed would be highly appreciated.
Below is a redacted version of the algorithm; I tried to keep the parts that matter for memory usage:
    def expmv(A, B, t, s, m):
        # A is the matrix for mexp, B is a few vectors
        F = B
        for ii in range(s):
            for jj in range(m):
                B = (t / (s * (jj+1))) * (A @ B.T).T
                F = F + B
            B = F
        return F
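One option that might reduce the stored intermediates without writing a custom Function is activation checkpointing. The sketch below is an assumption about what could help, not a measured result: it wraps each outer iteration in torch.utils.checkpoint.checkpoint, so autograd should keep roughly one tensor per outer step (s in total) instead of one per matrix product (s·m), at the cost of recomputing each inner loop during the backward pass. The helper names taylor_step and expmv_checkpointed are hypothetical, and (A @ B.T).T is written as the equivalent B @ A.T.

```python
import torch
from torch.utils.checkpoint import checkpoint


def taylor_step(A, B, t, s, m):
    """One outer iteration: apply the degree-m Taylor polynomial of (t/s)*A to the rows of B."""
    F = B
    for jj in range(m):
        # Same update as in expmv above; B @ A.T equals (A @ B.T).T
        B = (t / (s * (jj + 1))) * (B @ A.T)
        F = F + B
    return F


def expmv_checkpointed(A, B, t, s, m):
    """Hypothetical variant of expmv that checkpoints each outer iteration."""
    F = B
    for _ in range(s):
        # The m inner products are not saved for backward;
        # they are recomputed from the inputs of this step when gradients are needed.
        F = checkpoint(taylor_step, A, F, t, s, m, use_reentrant=False)
    return F


# Small sanity check against the dense matrix exponential (rows of B are the vectors).
torch.manual_seed(0)
A = (0.1 * torch.randn(5, 5, dtype=torch.float64)).requires_grad_()
B = torch.randn(3, 5, dtype=torch.float64)
out = expmv_checkpointed(A, B, t=1.0, s=4, m=12)
ref = B @ torch.matrix_exp(A).T
print(torch.allclose(out, ref))
```

Whether this actually lowers peak memory in a given setting would need to be measured, for example with torch.cuda.max_memory_allocated on a GPU run; the trade-off is roughly one extra forward pass of compute.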