I would like to use a custom attention mask for each element of my batch. Suppose I have a batch size of 2, with a single query and 3 keys, and my custom mask looks like:

```
mask = [
[[-inf, 0, 0]],
[[0, -inf, 0]]
]
```

Now suppose my MHA layer uses N heads. The documentation says the custom mask should then be of size (2 * N, 1, 3). How should I set up my mask? Would I repeat it like `mask = mask.repeat(N, 1, 1)`, or would I need to make it like:

```
mask = [
[[-inf, 0, 0]],
...
...
...
[[-inf, 0, 0]],
[[0, -inf, 0]],
...
...
...
[[0, -inf, 0]]
]
```