Using custom attention masks

I would like to use a custom attention mask for each element of my batch. Suppose I have a batch size of 2, with a single query and 3 keys, and my custom mask looks like:

mask = [
        [[-inf, 0, 0]],
        [[0, -inf, 0]]
       ]

Now, suppose my MHA layer uses N heads. The documentation says the custom mask should then be of size (2 * N, 1, 3). How should I set up my mask? Should I repeat it with mask = mask.repeat(N, 1, 1), or would I need to build it so that each batch element's row appears N consecutive times, like below (both layouts are sketched in code after the block)?

mask = [
        [[-inf, 0, 0]],
        ...
        ...
        ...
        [[-inf, 0, 0]],
        [[0, -inf, 0]],
        ...
        ...
        ...
        [[0, -inf, 0]]
       ]
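To make the two options concrete, here is a minimal sketch of how each layout would be constructed, assuming PyTorch tensors and an example value of N = 4 (the variable names and the value of N are just for illustration):

    import torch

    # Example setup from the question: batch size 2, 1 query, 3 keys, N heads.
    N = 4                      # example number of attention heads
    neg_inf = float("-inf")

    # Per-sample mask of shape (2, 1, 3), one row per batch element.
    mask = torch.tensor([
        [[neg_inf, 0., 0.]],
        [[0., neg_inf, 0.]],
    ])

    # Option A: tile the whole batch N times
    # -> ordering along dim 0 is (b0, b1, b0, b1, ...)
    mask_a = mask.repeat(N, 1, 1)               # shape (2 * N, 1, 3)

    # Option B: repeat each batch element's mask N consecutive times
    # -> ordering is (b0 x N, then b1 x N), i.e. the layout written out above
    mask_b = mask.repeat_interleave(N, dim=0)   # shape (2 * N, 1, 3)

    print(mask_a.shape, mask_b.shape)           # both torch.Size([8, 1, 3])

Both produce the required (2 * N, 1, 3) shape; the question is which ordering of the first dimension the MHA layer actually expects.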