How to parallelize scattered layer norm?

I’m working with sets of different sizes, with input dimensions [i, e] where i indexes the set elements and e is the embedding dimension. To avoid zero padding I batch along i and use the pytorch_scatter library (Scatter — pytorch_scatter 2.1.1 documentation) to take the mean while respecting the batching.
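For context, the per-set mean part is straightforward with scatter_mean (toy example, shapes and names are just illustrative):

```python
import torch
from torch_scatter import scatter_mean

# Two sets of sizes 3 and 2, embedding dim 4 (hypothetical numbers).
x = torch.randn(5, 4)                      # [i, e]
index = torch.tensor([0, 0, 0, 1, 1])      # set id of each row
set_means = scatter_mean(x, index, dim=0)  # [num_sets, e], no padding needed
```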

I’m having a lot more difficulty getting layer norm to work.
The input to my layer norm function is my batch with dimensions [i, e] and my “batch_mask”, which provides the indices where one batch ends and the next one begins. I take the mean of each individual batch, the std of each individual batch, then concatenate the results back together.
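For concreteness, for two sets of sizes 3 and 2 the boundary indices would look something like this (illustrative values only):

```python
# Rows 0-2 belong to the first set, rows 3-4 to the second (hypothetical example).
Batch_Mask = {"iterate": [0, 3, 5]}  # consecutive pairs give the [start, end) slices
```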

This is extremely slow, and my GPU utilization is at 5% as a result. Is there a way I can make the following operation parallel given that the step size varies?

new_x = []
for i, j in zip(Batch_Mask["iterate"][:-1], Batch_Mask["iterate"][1:]):
    centered = x[i:j] - torch.mean(x[i:j])
    new_x.append(centered / torch.sqrt(torch.std(x[i:j]) + eps))
new_x = torch.cat(new_x)
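For concreteness, this is the kind of vectorized formulation I’m hoping exists: a minimal, untested sketch that mirrors the loop using scatter_mean / scatter_std over a per-row set id built from the boundary indices. scattered_layer_norm and the id construction are just my placeholders.

```python
import torch
from torch_scatter import scatter_mean, scatter_std

def scattered_layer_norm(x, bounds, eps=1e-5):
    # x: [i, e]; bounds: boundary indices, e.g. [0, 3, 5] for sets of sizes 3 and 2.
    sizes = bounds[1:] - bounds[:-1]
    index = torch.repeat_interleave(torch.arange(len(sizes)), sizes).to(x.device)

    # Statistics over *all* elements of each set, as in the loop above:
    # repeat the per-row set id for every embedding element, then scatter.
    flat_index = index.repeat_interleave(x.size(1))
    flat_x = x.reshape(-1)
    mean = scatter_mean(flat_x, flat_index)  # [num_sets]
    std = scatter_std(flat_x, flat_index)    # [num_sets]

    centered = x - mean[index].unsqueeze(1)
    return centered / torch.sqrt(std + eps)[index].unsqueeze(1)

# new_x = scattered_layer_norm(x, torch.as_tensor(Batch_Mask["iterate"]), eps)
```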