Dynamic shape in vmap

I am trying to implement an algorithm that needs to deal with dynamic shapes for every sample in a batch, e.g., processing only the part of each sample that exceeds a threshold and then placing the processed values back in their original positions.
I first tried to do this by using vmap. However, it seems like vmap does not support dynamic shapes.
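A minimal sketch of the failing pattern, assuming PyTorch 2.0+ where vmap lives at `torch.func.vmap` (older setups ship it as `functorch.vmap`). The toy data, the 0.5 threshold, and the doubling step are illustrative only. Boolean indexing returns a different number of elements per sample, so vmap raises an error instead of returning a batched result:

```python
import torch
from torch.func import vmap  # PyTorch 2.0+; older versions: functorch.vmap

def per_sample(x, threshold=0.5):
    # Boolean indexing: how many elements survive depends on the data,
    # so the output shape would differ from sample to sample.
    return x[x > threshold] * 2.0

batch = torch.tensor([[0.1, 0.9, 0.7],
                      [0.8, 0.2, 0.3]])  # 2 vs. 1 elements above 0.5

try:
    vmap(per_sample)(batch)
    vmap_failed = False
except Exception as err:  # vmap cannot batch data-dependent output shapes
    vmap_failed = True
    print(type(err).__name__, err)
```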
Though I could work around this by processing the data sequentially, sample by sample, is there a better solution, or a future release that is likely to solve this issue?
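For reference, the sequential workaround could look roughly like this. It is a sketch: `process` is a hypothetical stand-in for the real per-sample computation, and the data and threshold are made up. Each sample is selected with a boolean mask, processed, and written back into a copy of the batch at the same positions:

```python
import torch

def process(values):
    # Hypothetical stand-in for the real computation on the selected values.
    return values * 2.0

def process_above_threshold_loop(batch, threshold):
    out = batch.clone()
    for i in range(batch.shape[0]):
        mask = batch[i] > threshold               # data-dependent selection
        out[i][mask] = process(batch[i][mask])    # write back in place
    return out

batch = torch.tensor([[0.1, 0.9, 0.7],
                      [0.8, 0.2, 0.3]])
result = process_above_threshold_loop(batch, 0.5)
print(result)
```

This works for any `process`, including ones that depend on how many values were selected, but the Python loop runs once per sample.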


Sounds like you probably want to implement a mask, something like this:

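import torch
# mask is 1.0 where the input exceeds the threshold and 0.0 elsewhere, so shapes stay fixed
mask = (input_tensor > threshold).float()
output_tensor = mask * custom_operation_if_mask_is_true(input_tensor) + (1 - mask) * input_tensor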
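Because the mask has the same shape as the input, this approach composes with vmap. A sketch under the same assumptions as above (PyTorch 2.0+ `torch.func.vmap`; `process` is a hypothetical elementwise operation standing in for `custom_operation_if_mask_is_true`). It uses `torch.where` for the selection and passes a separate threshold for each sample:

```python
import torch
from torch.func import vmap  # PyTorch 2.0+; older versions: functorch.vmap

def process(values):
    # Hypothetical elementwise operation; it is evaluated on every element,
    # and torch.where keeps its result only where the mask is True.
    return values * 2.0

def per_sample(x, threshold):
    mask = x > threshold                    # same shape as x: no dynamic shapes
    return torch.where(mask, process(x), x)

batch = torch.tensor([[0.1, 0.9, 0.7],
                      [0.8, 0.2, 0.3]])
thresholds = torch.tensor([0.5, 0.25])      # one threshold per sample

result = vmap(per_sample)(batch, thresholds)
print(result)
```

Unlike `mask * a + (1 - mask) * b`, `torch.where` does not turn an `inf` or `nan` from the unselected branch into `nan` in the output. The trade-off is that the operation runs on every element, so it only fits operations that are elementwise or can be rewritten with the full-size mask (e.g., a mean over the selected values computed as `(x * mask).sum() / mask.sum()`).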