Based on this thread, I found a way to eliminate the inner for loop using `bmm`. Profiling indicates this has removed a lot of work from the CPU (especially in the backward pass) and resulted in a considerable speedup.
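For anyone landing here later: the original loop isn't shown, so the following is only a minimal sketch. It assumes the inner loop did one small matrix multiply per batch element. The function names `loop_matmul` and `batched_matmul` and the tensor shapes are hypothetical. The loop version issues one matmul (and one autograd node) per iteration. `torch.bmm` does the whole batch in a single call, which is where the CPU savings in forward and backward would come from.

```python
import torch

def loop_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical "before": one small matmul per batch element in a Python loop.
    # Each iteration dispatches its own op and records its own autograd node.
    return torch.stack([a[i] @ b[i] for i in range(a.shape[0])])

def batched_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # "After": a single batched matrix multiply.
    # a: (batch, n, m), b: (batch, m, p) -> result: (batch, n, p)
    return torch.bmm(a, b)

torch.manual_seed(0)
a = torch.randn(8, 4, 5, requires_grad=True)
b = torch.randn(8, 5, 3, requires_grad=True)

out_loop = loop_matmul(a, b)
out_bmm = batched_matmul(a, b)
print(out_bmm.shape)  # torch.Size([8, 4, 3])

# Gradients should match too, so the backward pass is unchanged numerically.
ga_loop, gb_loop = torch.autograd.grad(out_loop.sum(), (a, b))
ga_bmm, gb_bmm = torch.autograd.grad(out_bmm.sum(), (a, b))
print(torch.allclose(ga_loop, ga_bmm, atol=1e-5), torch.allclose(gb_loop, gb_bmm, atol=1e-5))
```

If your loop body does more than a plain matmul, the same idea applies as long as every iteration uses matrices of the same shape, so they can be stacked along a leading batch dimension.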