nn.MultiheadAttention throwing NaNs for entire batch

Hey guys,

I’ve started using torch’s latest nn.MultiheadAttention and noticed a difference compared to my own implementation: when part of the input batch contains NaNs, the forward pass returns an output tensor that is entirely NaN. With my own implementation, only the outputs corresponding to the NaN inputs come back as NaN. Here’s how I reproduced this:

from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange

class MultiheadAttention(nn.Module):
    def __init__(self, dim: int, heads: int, dim_head: Optional[int] = None):
        super().__init__()
        self.dim_head = dim // heads if dim_head is None else dim_head
        self.heads = heads
        self.scale_factor = self.dim_head**-0.5
        self.to_qvk = nn.Linear(dim, self.dim_head * heads * 3, bias=False)
        self.to_out = nn.Linear(self.dim_head * heads, dim, bias=False)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # project to q, k, v and split into heads: each is (b, heads, t, dim_head)
        qkv = self.to_qvk(x)
        q, k, v = tuple(
            rearrange(qkv, "b t (d h k) -> k b h t d", k=3, h=self.heads)
        )
        # (b, h, i, j) attention scores between query position i and key position j
        scaled_dot_prod = (
            torch.einsum("b h i d , b h j d -> b h i j", q, k)
            * self.scale_factor
        )
        if mask is not None:
            assert mask.shape == scaled_dot_prod.shape[2:]
            scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -torch.inf)
        attention = torch.softmax(scaled_dot_prod, dim=-1)
        # weighted sum of values, then merge heads back to (b, t, dim)
        context = torch.einsum("b h i j , b h j d -> b h i d", attention, v)
        context = rearrange(context, "b h t d -> b t (h d)")
        out = self.to_out(context)
        return out
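
For context, every op in this implementation treats batch elements independently (the einsums and the softmax never mix the batch dim), which is why the NaNs stay confined to the NaN half of the batch. A quick sanity check with softmax alone (arbitrary shapes, just to illustrate):

x = torch.randn(2, 4)
x[1] = float("nan")  # make the second row all NaN
print(torch.softmax(x, dim=-1))  # row 0 stays finite, row 1 is all NaN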

The following would pass:

size = 2
tokens = 2
dim = 8
heads = 2
inputs = torch.randn((size, tokens, dim), requires_grad=True).cuda()

model = MultiheadAttention(dim=dim, heads=heads).cuda()
outputs = model(inputs)
assert not torch.isnan(outputs).any()
assert not torch.isinf(outputs).any()

# double the batch: first half is the original inputs, second half is all NaN
with_nans = torch.cat((inputs, torch.ones_like(inputs) * float("nan")))
outputs_with_nans = model(with_nans)

print(outputs_with_nans.shape)
print(torch.isnan(outputs_with_nans).sum().item())
print(outputs.shape[0] * outputs.shape[1] * outputs.shape[2])
assert torch.isnan(outputs_with_nans).sum() == outputs.shape[0] * outputs.shape[1] * outputs.shape[2]

This would fail, which is the problem:

size = 2
tokens = 2
dim = 8
heads = 2
inputs = torch.randn((size, tokens, dim), requires_grad=True).cuda()

model = nn.MultiheadAttention(embed_dim=dim, num_heads=heads).cuda()

with torch.backends.cuda.sdp_kernel(enable_flash=True):
    outputs = model(inputs, inputs, inputs, need_weights=False)[0]
assert not torch.isnan(outputs).any()
assert not torch.isinf(outputs).any()

with_nans = torch.cat((inputs, torch.ones_like(inputs) * float("nan")))

with torch.backends.cuda.sdp_kernel(enable_flash=True):
    outputs_with_nans = model(with_nans, with_nans, with_nans, need_weights=False)[0]

print(outputs_with_nans.shape)
print(torch.isnan(outputs_with_nans).sum().item())
print(outputs.shape[0] * outputs.shape[1] * outputs.shape[2])
assert torch.isnan(outputs_with_nans).sum() == outputs.shape[0] * outputs.shape[1] * outputs.shape[2]
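
In case it helps narrow things down, here’s how I’d call the SDPA primitive directly with the same NaN-padded tensor (the per-head reshape below treats dim 0 as the batch dim, mirroring my own implementation, so that’s an assumption on my part):

import torch.nn.functional as F

# reuse with_nans and heads from the snippet above; head dim is dim // heads
q = k = v = rearrange(with_nans, "b t (h d) -> b h t d", h=heads)

with torch.backends.cuda.sdp_kernel(enable_flash=True):
    sdpa_out = F.scaled_dot_product_attention(q, k, v)

print(torch.isnan(sdpa_out).sum().item(), sdpa_out.numel())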