Apply function along axis


I have a function that works on a tensor of shape (B,1) and returns a tensor of shape (B,1).
I want to apply the same function to a tensor of shape (B,S,1), separately for each index along dimension S.

How can I do that in torch?

Thank you

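It depends on the function used, but broadcasting might just work: if the function only uses elementwise operations (or reductions over dim 0 with keepdim=True), you can pass the (B,S,1) tensor to it directly.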
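A minimal sketch of the options, assuming PyTorch 2.x (where `torch.vmap` is available) and a hypothetical function `f` that subtracts the batch mean. The explicit loop with `torch.stack` serves as a reference. Broadcasting matches it here only because `f` reduces over dim 0 with `keepdim=True`. `torch.vmap` maps `f` over dim 1 for any vmap-compatible function.

```python
import torch

# Hypothetical example function mapping (B, 1) -> (B, 1).
# It reduces over the batch dimension (dim 0), so it is not purely elementwise.
def f(x: torch.Tensor) -> torch.Tensor:
    return x - x.mean(dim=0, keepdim=True)

torch.manual_seed(0)
B, S = 4, 3
x = torch.randn(B, S, 1)

# Reference: explicit Python loop over S, applying f to each (B, 1) slice
# and stacking the results back along dim 1 -> (B, S, 1).
ref = torch.stack([f(x[:, s, :]) for s in range(S)], dim=1)

# Option 1: broadcasting. The mean over dim 0 with keepdim=True has shape
# (1, S, 1), so each slice along S is centered independently.
out_bcast = f(x)

# Option 2: torch.vmap. in_dims=1 makes f see (B, 1) slices along S,
# and out_dims=1 stacks the (B, 1) outputs back along dim 1.
out_vmap = torch.vmap(f, in_dims=1, out_dims=1)(x)

print(torch.allclose(ref, out_bcast), torch.allclose(ref, out_vmap))
```

Reshaping to (B*S,1), calling `f`, and reshaping back is only safe when `f` treats each row independently. With a batch reduction like the one above, it would mix values from different slices along S. `torch.vmap` also cannot handle functions with data-dependent Python control flow (for example, calling `.item()` inside `f`). In that case, the loop is the fallback.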