Applying functions to certain columns

Hi,
I have a problem that requires applying function objects (like torch.sin) to some columns of a matrix X. Right now I solve this with a list of tuples like [(function, col)], where function is the function object and col is the column it acts on.
In each step I stack the resulting columns and return the matrix. I'm not sure this is the most performant way, because it runs a list comprehension over all columns on every evaluation. Doing something like

X[:,i] = torch.sin(X[:,i]) 

is also not possible, because the in-place assignment breaks autograd. I was wondering whether the new torch.vmap could help here?
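One possible out-of-place alternative, sketched under the assumption that the columns to transform are known up front as an index list (`sin_cols` is a hypothetical name): build a boolean column mask once and let `torch.where` choose between `torch.sin(X)` and `X`. The in-place version fails because it writes into a tensor that autograd either treats as a leaf or has saved for the backward pass of `sin`; here nothing is written in place, so `backward()` works. Since `torch.sin` is already elementwise over the whole tensor, `torch.vmap` does not seem necessary for this case.

```python
import torch

X = torch.randn(4, 3, requires_grad=True)
sin_cols = torch.tensor([0, 2])  # hypothetical: columns that should get torch.sin

# Boolean mask over columns; it broadcasts across all rows of X.
mask = torch.zeros(X.shape[1], dtype=torch.bool)
mask[sin_cols] = True

# Out-of-place: sin is evaluated on the whole tensor, and torch.where picks
# sin(X) in masked columns and X everywhere else. No saved tensor is modified.
Y = torch.where(mask, torch.sin(X), X)

# Gradients flow normally: cos(X) in masked columns, 1 elsewhere.
Y.sum().backward()
```

A caveat of this design: the function is evaluated on every column. For functions that are undefined on some inputs (e.g. `torch.sqrt` on negative values), the discarded branch can still inject NaN into the gradient through `torch.where`, so an index-based gather/scatter may be safer there.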
Another problem with my approach is that, because of the list mentioned above, I can't JIT the module: TorchScript does not support the function (functor) type stored in it.
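If the set of functions is known when the module is built, one way around the TorchScript limitation is to keep only the column indices as data and write the function calls out explicitly in `forward`. The sketch below assumes two hypothetical, non-overlapping column groups (`sin_cols`, `exp_cols`); it uses `index_select`/`index_copy` so each function is called once per group rather than once per column, and everything stays out-of-place and differentiable.

```python
import torch
from typing import List


class ColumnTransform(torch.nn.Module):
    # Hypothetical module: the column groups are stored as index tensors, and the
    # functions are written out in forward, so no Python function object is stored.
    def __init__(self, sin_cols: List[int], exp_cols: List[int]):
        super().__init__()
        # Buffers move with the module on .to(device) and are visible to TorchScript.
        self.register_buffer("sin_cols", torch.tensor(sin_cols, dtype=torch.long))
        self.register_buffer("exp_cols", torch.tensor(exp_cols, dtype=torch.long))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gather each column group, transform it in one call, scatter it back
        # into a copy of x. Columns in neither group pass through unchanged.
        out = x.index_copy(1, self.sin_cols, torch.sin(x.index_select(1, self.sin_cols)))
        out = out.index_copy(1, self.exp_cols, torch.exp(x.index_select(1, self.exp_cols)))
        return out


module = torch.jit.script(ColumnTransform(sin_cols=[0, 2], exp_cols=[1]))
X = torch.randn(5, 4, requires_grad=True)
Y = module(X)
Y.sum().backward()
```

The trade-off is flexibility: adding another function means adding another buffer and another line in `forward`, in exchange for a module that scripts and avoids the per-column loop.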

Do you have any tips on how to improve the implementation?