Check out a package I made for this purpose: toshas/torch-discounted-cumsum on GitHub (Fast Discounted Cumulative Sums in PyTorch).
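For reference, a discounted cumulative sum of a sequence `x` with discount factor `gamma` is usually taken right to left, as in reinforcement-learning returns: `y[t] = x[t] + gamma * y[t + 1]`. The sketch below is a naive pure-Python loop showing that definition. It is an illustration written for this post and is not the package's API. The function name `discounted_cumsum_right` is hypothetical, and the right-to-left direction is an assumption. A sequential loop like this is fine for checking results on small inputs, but it does not parallelize, which is the usual reason to reach for a dedicated PyTorch implementation.

```python
def discounted_cumsum_right(x, gamma):
    """Naive reference (hypothetical helper): y[t] = x[t] + gamma * y[t + 1].

    Scans the sequence from right to left, carrying the running discounted sum.
    """
    y = [0.0] * len(x)
    running = 0.0
    for t in range(len(x) - 1, -1, -1):
        running = x[t] + gamma * running  # fold in the current element
        y[t] = running
    return y


print(discounted_cumsum_right([1.0, 1.0, 1.0], 0.5))  # [1.75, 1.5, 1.0]
```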