Efficient implementation of softmax/attention and resnet updates

I’m reasonably up to speed with NN architectures, but I could do with some pointers to efficient implementations of softmax/attention and residual-style updates.

Let’s say I have a million vectors and I want to propagate these a thousand time steps into the future with attention selection and then backprop the lot. At each step I compute a distance from each of a million keys then do a big softmax to get my attention weighting, a[t,i]. I now update my vectors with v[t+1,i] = (1 - a[t,i]) * v[t,i] + a[t,i] * d[t,i], and go on to do this another 999 times.
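To make that concrete, here’s a dense version of what I’m describing, scaled down so it actually runs; the names and the query/placeholder tensors are just mine for illustration, not anything I’m committed to:

```python
import torch

# Dense baseline, scaled down (really N ~ 1e6, T ~ 1000).
N, D, T = 10_000, 16, 100
v = torch.randn(N, D)          # the vectors v[t, i]
keys = torch.randn(N, D)       # the million keys
queries = torch.randn(T, D)    # whatever the keys are compared against each step (my assumption)

for t in range(T):
    scores = keys @ queries[t]             # distance/similarity to every key
    a = torch.softmax(scores, dim=0)       # attention weighting a[t, i], shape (N,)
    d_t = torch.randn(N, D)                # placeholder for d[t, i]
    v = (1 - a).unsqueeze(1) * v + a.unsqueeze(1) * d_t   # v[t+1, i]
```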

Now most a[t,i] will be essentially zero, so all of this is really, really inefficient. Let’s not worry about the computation of a[t,i]; let’s assume that SVD or something similar can give me a shortlist, and I’m happy pruning out the weights that are nearly zero. Let’s say this gives me ten non-zero a[t,i] that I care about, and the rest I’m happy to treat as exactly zero.
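For concreteness, here’s roughly the kind of pruning I have in mind, with torch.topk standing in for whatever actually produces the shortlist (SVD, hashing, whatever); again the names and sizes are just mine:

```python
import torch

# Shortlist step: keep only the top-k weights and renormalise over them,
# treating everything else as exactly zero.
N, D, k = 10_000, 16, 10                 # scaled down; really N ~ 1e6, k ~ 10
keys = torch.randn(N, D)
query = torch.randn(D)

scores = keys @ query
top_scores, top_idx = torch.topk(scores, k)    # the ten candidates I actually keep
a_top = torch.softmax(top_scores, dim=0)       # softmax over the shortlist only
```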

What I’m thinking is that I store the time step n[t,i] in which the last non-zero update occurred. I can then write v[t+1,i] = (1 - a[t,i]) * v[n[t,i],i] + a[t,i] * d[t,i] and only do this for non-zero a[t,i]. But I don’t know if that’s the best way to do it in PyTorch; it’s certainly not very matrix friendly.
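Here’s a rough sketch of that bookkeeping idea, under my own assumptions about shapes and where d[t,i] comes from (the per-step d_t is just a placeholder). I keep one current copy of the vectors, touch only the k shortlisted rows each step, and record in n the step of each row’s last update:

```python
import torch

# Lazy/sparse update sketch, scaled down so it runs end to end.
N, D, T, k = 10_000, 16, 100, 10
v = torch.randn(N, D, requires_grad=True)
keys = torch.randn(N, D)
queries = torch.randn(T, D)

n = torch.zeros(N, dtype=torch.long)     # n[i]: step of the last non-zero update (bookkeeping only)
cur = v                                  # current value of every vector
for t in range(T):
    scores = keys @ queries[t]
    top_scores, idx = torch.topk(scores, k)            # shortlist stand-in, as above
    a = torch.softmax(top_scores, dim=0).unsqueeze(1)  # shape (k, 1)
    d_t = torch.randn(k, D)                            # placeholder for d[t, i] on the shortlist

    updated = (1 - a) * cur[idx] + a * d_t             # touch only the k selected rows
    cur = cur.index_copy(0, idx, updated)              # out-of-place, so autograd can follow it
    n[idx] = t + 1                                     # not part of the graph

loss = cur.sum()
loss.backward()                                        # backprop through all T steps
```

The out-of-place index_copy keeps autograd happy, but it copies the whole tensor every step, which is exactly the kind of overhead I suspect there’s a smarter way around.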

Attention weights are effectively sparse in practice, and transformers dominate our modelling, so surely this is a problem that many people have had. So, am I thinking along the right lines, and if so, can someone point me to some papers/implementations that will speed up my PyTorch coding? Thanks!