Are flex attention and SDPA outputs natively equivalent?

I’d expect you to be able to use regular batched inputs. Maybe the examples here would be helpful: attention-gym/examples/flex_attn.ipynb in the pytorch-labs/attention-gym repo on GitHub.
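
As a minimal sketch of what "regular batched inputs" means here (not from the thread, just an illustration): pass the same `(B, H, S, D)` tensors to both `torch.nn.functional.scaled_dot_product_attention` and `torch.nn.attention.flex_attention.flex_attention` with no `score_mod` or mask, and compare the outputs. The shapes, tolerances, and device assumptions below are my own; `flex_attention` requires a recent PyTorch (2.5+) and in practice works best on CUDA.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

# Arbitrary sizes for illustration: batch, heads, sequence length, head dim.
B, H, S, D = 2, 8, 128, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)

# Plain SDPA: full attention, no mask, no dropout.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# flex_attention with no score_mod / block_mask should reduce to the same
# full attention, so the two outputs should agree up to floating-point error.
out_flex = flex_attention(q, k, v)

print(torch.allclose(out_sdpa, out_flex, atol=1e-5, rtol=1e-5))
```

If the outputs differ by more than a small tolerance with a non-trivial `score_mod` or `block_mask`, that's expected, since SDPA isn't computing the same thing anymore.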