How can I efficiently compute the gradient for each training sample?

It should work for those architectures as well. What's missing is support for other layers with trainable parameters (e.g. the `MultiheadAttention` layer).
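For reference, one common way to get per-sample gradients in recent PyTorch (≥ 2.0) is the `torch.func` API, which composes `grad` with `vmap` over the batch dimension. The sketch below uses a toy `nn.Linear` model and random data purely for illustration; the same pattern applies to any module whose forward can be expressed through `functional_call`:

```python
import torch
from torch.func import functional_call, grad, vmap

# Toy model and data, for illustration only
model = torch.nn.Linear(4, 2)
loss_fn = torch.nn.CrossEntropyLoss()
data = torch.randn(8, 4)          # batch of 8 samples
targets = torch.randint(0, 2, (8,))

# Detach parameters into a plain dict so functional_call can use them
params = {k: v.detach() for k, v in model.named_parameters()}

def compute_loss(params, sample, target):
    # Re-add a batch dimension of 1, since the model expects batched input
    batch = sample.unsqueeze(0)
    tgt = target.unsqueeze(0)
    preds = functional_call(model, params, (batch,))
    return loss_fn(preds, tgt)

# grad differentiates w.r.t. params; vmap maps over the batch dimension,
# yielding one gradient per sample instead of a summed batch gradient
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
    params, data, targets
)

# per_sample_grads["weight"] has shape (8, 2, 4): one (2, 4) gradient per sample
```

This avoids the classic workaround of looping over the batch and calling `backward()` once per sample, and it works for any layer that `vmap` can trace through.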