Correct way to batch custom masks in SDPA

Hello everyone,
I’m trying to pass a custom mask to F.scaled_dot_product_attention. The mask differs for every element in the batch, and I couldn’t figure out how to batch it.
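For context, here is a minimal sketch of the setup. The tensor sizes and the random boolean mask are made up for illustration. My understanding from the docs is that attn_mask only has to be broadcastable to the attention-weight shape (B, H, L, S), so a per-sample mask of shape (B, L, S) would need a singleton head dimension, but I'm not sure this is the intended way:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: batch, heads, query length, key length, head dim.
B, H, L, S, D = 2, 4, 5, 5, 8

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# One boolean mask per batch element: True = may attend, False = masked out.
mask = torch.rand(B, L, S) > 0.3
# Keep the diagonal open so no query row is fully masked (avoids NaN rows).
mask |= torch.eye(L, dtype=torch.bool)

# Add a singleton head dimension so the mask broadcasts over all heads.
attn_mask = mask.unsqueeze(1)  # (B, 1, L, S)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Manual reference to check that each sample really used its own mask.
scores = (q @ k.transpose(-2, -1)) / D ** 0.5
scores = scores.masked_fill(~attn_mask, float("-inf"))
ref = scores.softmax(dim=-1) @ v
```

If the mask also had to differ per head, I assume it would simply be built with the full (B, H, L, S) shape instead of using unsqueeze.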
Thanks!