Hi all,
attention_logits is of size torch.Size([16, 12, 260, 260])
attention_mask is of size torch.Size([16, 260])
I want to perform attention_logits = attention_logits.masked_fill(attention_mask == 0, -1e7),
but it throws RuntimeError: The size of tensor a (16) must match the size of tensor b (260) at non-singleton dimension 2.
I think this is a broadcasting issue: broadcasting aligns shapes from the trailing dimension, so the mask's [16, 260] gets lined up against the last two dims [260, 260] of the logits.
What's the best way to apply this masked_fill along the rows of the attention_logits tensor?
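For reference, here is a minimal sketch of what I've tried so far (shapes and dummy data are made up to match the sizes above): inserting singleton dimensions into the mask so it broadcasts over the head and query dimensions, masking along the key (last) dimension.

```python
import torch

batch, heads, seq = 16, 12, 260

# Dummy data matching the shapes in question
attention_logits = torch.randn(batch, heads, seq, seq)   # [16, 12, 260, 260]
attention_mask = torch.randint(0, 2, (batch, seq))       # [16, 260]

# Reshape the mask to [16, 1, 1, 260] so the singleton dims broadcast
# over the head and query dimensions; the mask is applied per key position.
mask = attention_mask[:, None, None, :]

attention_logits = attention_logits.masked_fill(mask == 0, -1e7)
```

If instead the goal is to mask query rows rather than key columns, the mask would go in the third dimension: attention_mask[:, None, :, None].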