Regarding Scaled Dot Product Attention

I was trying to implement the attention computation from scratch, so I used torch.einsum for the matrix multiplication between the query and key vectors. This step always threw CUDA OOM errors. When I used F.scaled_dot_product_attention instead, my model worked fine and didn't throw any OOM errors. Can anyone help me understand whether torch.einsum is a GPU-memory-intensive operation? And why is PyTorch's implementation of scaled_dot_product_attention less GPU-memory-intensive?

Code for my from-scratch implementation:

Using torch.einsum:

import torch
import torch.nn as nn

class ConvMHAttn_Scratch(nn.Module):
    """
    Convolutional Multi-Head Attention Network
    """
    def __init__(self, in_channels, out_channels, num_heads):
        super().__init__()
        assert out_channels % num_heads == 0, "out_channels must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = out_channels // num_heads  # Each head's dimensionality
        self.q_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # Query projection
        self.k_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # Key projection
        self.v_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # Value projection
        self.out_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)  # Output projection

    def forward(self, query, key, value):
        assert query.shape == key.shape, "Query and Key must have the same shape"
        assert key.shape == value.shape, "Key and Value must have the same shape"
        B, _, Hq, Wq = query.shape  # Query shape (target input)
        _, _, Hkv, Wkv = key.shape  # Key-Value shape (source input)

        # Step 1: Project to Query, Key, and Value and split into heads
        Q = self.q_conv(query).view(B, self.num_heads, self.head_dim, Hq * Wq)  # (B, heads, C_head, Hq*Wq)
        K = self.k_conv(key).view(B, self.num_heads, self.head_dim, Hkv * Wkv)  # (B, heads, C_head, Hkv*Wkv)
        V = self.v_conv(value).view(B, self.num_heads, self.head_dim, Hkv * Wkv)  # (B, heads, C_head, Hkv*Wkv)

        # Step 2: Scaled dot-product attention
        # The score tensor is (B, heads, Hq*Wq, Hkv*Wkv): quadratic in the number of pixels
        attn = torch.einsum('bhci,bhcj->bhij', Q, K) / (self.head_dim ** 0.5)
        attn = torch.softmax(attn, dim=-1)

        # Step 3: Apply the attention weights to V
        out = torch.einsum('bhij,bhcj->bhci', attn, V)  # (B, heads, C_head, Hq*Wq)
        out = out.reshape(B, -1, Hq, Wq)  # (B, C, H, W): concatenate the heads

        # Step 4: Final projection (mixes the heads)
        return self.out_conv(out)

Using F.scaled_dot_product_attention:

import torch.nn as nn
import torch.nn.functional as F

class ConvMHAttn(nn.Module):
    """
    Convolutional Multi-Head Attention Network
    """
    def __init__(self, in_channels, out_channels, num_heads):
        super().__init__()
        assert out_channels % num_heads == 0, "out_channels must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = out_channels // num_heads  # Each head's dimensionality
        self.q_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # Query projection
        self.k_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # Key projection
        self.v_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # Value projection
        self.out_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)  # Output projection

    def forward(self, query, key, value):
        assert query.shape == key.shape, "Query and Key must have the same shape"
        assert key.shape == value.shape, "Key and Value must have the same shape"
        B, _, Hq, Wq = query.shape  # Query shape (target input)
        _, _, Hkv, Wkv = key.shape  # Key-Value shape (source input)

        # Step 1: Project to Query, Key, and Value and split into heads
        Q = self.q_conv(query).view(B, self.num_heads, self.head_dim, Hq * Wq)  # (B, heads, C_head, Hq*Wq)
        K = self.k_conv(key).view(B, self.num_heads, self.head_dim, Hkv * Wkv)  # (B, heads, C_head, Hkv*Wkv)
        V = self.v_conv(value).view(B, self.num_heads, self.head_dim, Hkv * Wkv)  # (B, heads, C_head, Hkv*Wkv)

        # Step 2: Scaled dot-product attention
        # SDPA expects (B, heads, seq_len, head_dim), so swap the last two axes.
        # Passing (B, heads, C_head, H*W) directly makes SDPA attend over channels
        # (a C_head x C_head matrix) instead of over pixels.
        out = F.scaled_dot_product_attention(
            Q.transpose(-2, -1), K.transpose(-2, -1), V.transpose(-2, -1),
            dropout_p=0.2 if self.training else 0.0,  # no dropout at eval time
        )  # (B, heads, Hq*Wq, C_head)
        out = out.transpose(-2, -1).reshape(B, -1, Hq, Wq)  # (B, C, H, W): concatenate the heads

        # Step 3: Final projection (mixes the heads)
        return self.out_conv(out)
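A quick way to check the tensor layout, sketched with two hypothetical helpers (`einsum_attention`, `sdpa_attention`) and random tensors: `F.scaled_dot_product_attention` treats the second-to-last axis as the sequence and the last axis as the embedding, so with the `(B, heads, C_head, H*W)` layout used above, the last two axes have to be swapped before and after the call. With the swap, the result should match the einsum version up to floating-point error. Without it, SDPA attends over channels and returns a different tensor of the same shape.

```python
import torch
import torch.nn.functional as F

def einsum_attention(Q, K, V):
    # Q, K, V: (B, heads, C_head, N), the layout used in the posts above.
    scale = Q.shape[2] ** -0.5
    attn = torch.einsum('bhci,bhcj->bhij', Q, K) * scale   # (B, heads, N, N)
    attn = attn.softmax(dim=-1)
    return torch.einsum('bhij,bhcj->bhci', attn, V)        # (B, heads, C_head, N)

def sdpa_attention(Q, K, V):
    # SDPA wants (..., seq_len, head_dim): move the pixel axis N to dim -2.
    out = F.scaled_dot_product_attention(
        Q.transpose(-2, -1), K.transpose(-2, -1), V.transpose(-2, -1)
    )                                                      # (B, heads, N, C_head)
    return out.transpose(-2, -1)                           # back to (B, heads, C_head, N)

torch.manual_seed(0)
B, heads, C_head, N = 2, 4, 8, 16  # made-up sizes for the check
Q, K, V = (torch.randn(B, heads, C_head, N) for _ in range(3))

ref = einsum_attention(Q, K, V)
print(torch.allclose(ref, sdpa_attention(Q, K, V), atol=1e-5))  # True

# Without the transposes, SDPA uses C_head as the sequence axis (channel attention).
wrong = F.scaled_dot_product_attention(Q, K, V)
print(wrong.shape == ref.shape)  # True: same shape, different computation
```

Because the wrong layout still returns a tensor of the expected shape, the mistake raises no error; comparing against a reference like this is the easiest way to catch it.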

The documentation states that PyTorch's scaled_dot_product_attention has a memory-efficient attention implementation enabled by default:

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
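As a small sketch (the helper name `sdpa_backend_flags` is made up for illustration), you can inspect the global switches PyTorch exposes for the CUDA SDPA backends. These only say which kernels are *allowed*; which one actually runs depends on the inputs (device, dtype, shapes, mask, dropout). As far as I know, the FlashAttention kernel needs fp16/bf16 inputs, so an fp32 model would more likely go through the memory-efficient kernel.

```python
import torch

def sdpa_backend_flags() -> dict:
    # Global switches F.scaled_dot_product_attention consults when picking a
    # CUDA kernel. They are on by default; the actual choice also depends on
    # the inputs, so "enabled" does not mean "used".
    return {
        "flash": torch.backends.cuda.flash_sdp_enabled(),
        "mem_efficient": torch.backends.cuda.mem_efficient_sdp_enabled(),
        "math": torch.backends.cuda.math_sdp_enabled(),
    }

print(sdpa_backend_flags())
```

To test whether a fused kernel is what avoids the OOM, you could call `torch.backends.cuda.enable_mem_efficient_sdp(False)` and `torch.backends.cuda.enable_flash_sdp(False)` before the forward pass. SDPA may then fall back to the math implementation, which materializes the full attention matrix just like the einsum version.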

See here for benchmarks and code:


@J_Johnson Thank you for your reply.
Can you help me understand whether torch.einsum or einops is memory-intensive on the GPU?

I don't think that's the issue here. The issue is that the original attention mechanism is very memory-intensive. More memory-efficient versions have been proposed, which get you the same result or a "close enough" one while using far less memory.
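To illustrate that einsum itself is not the culprit, here is a small check with random, made-up sizes: the contraction used in the question is just a batched $Q^\top K$ product, and the result has shape `(B, heads, N, N)` either way. The memory goes into that `N x N` output, whichever function produces it.

```python
import torch

torch.manual_seed(0)
B, heads, C_head, N = 2, 4, 8, 64  # made-up sizes
Q = torch.randn(B, heads, C_head, N)
K = torch.randn(B, heads, C_head, N)

scores_einsum = torch.einsum('bhci,bhcj->bhij', Q, K)  # contraction from the question
scores_matmul = Q.transpose(-2, -1) @ K                # the same thing as an explicit batched matmul
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-5))  # True
print(scores_einsum.shape)  # torch.Size([2, 4, 64, 64])
```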

There are a lot of shortcuts in linear-algebra computations that can be exploited for efficiency; even something as simple as matrix multiplication can be reorganized to do less work or move less data. Using einsum on Q and K, and then on V, is straightforward but makes use of none of these shortcuts.
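One simple example of such a shortcut, sketched in plain PyTorch: process the queries a block at a time, so only a `(chunk_size, N)` slice of the score matrix is alive at once. This is query chunking, a generic trick; it is not the kernel `F.scaled_dot_product_attention` actually uses, and `naive_attention` / `chunked_attention` are hypothetical names. The arithmetic is identical and the result is exact.

```python
import torch

def naive_attention(q, k, v):
    # q, k, v: (B, heads, N, d). Materializes the full (N, N) score matrix.
    scores = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    return scores.softmax(dim=-1) @ v

def chunked_attention(q, k, v, chunk_size=128):
    # Same math, but only a (chunk_size, N) block of scores exists at a time,
    # so the peak extra memory is O(chunk_size * N) instead of O(N^2).
    scale = q.shape[-1] ** -0.5
    outs = []
    for start in range(0, q.shape[-2], chunk_size):
        q_blk = q[..., start:start + chunk_size, :]
        scores = q_blk @ k.transpose(-2, -1) * scale  # (B, heads, chunk, N)
        outs.append(scores.softmax(dim=-1) @ v)       # (B, heads, chunk, d)
    return torch.cat(outs, dim=-2)

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 1000, 16) for _ in range(3))
print(torch.allclose(naive_attention(q, k, v), chunked_attention(q, k, v), atol=1e-5))  # True
```

One caveat: this mainly lowers peak memory at inference time. Under autograd, every chunk's softmax output is still saved for the backward pass, so training memory stays O(N^2) unless the loop is combined with something like `torch.utils.checkpoint`.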


@Jivitesh_Sabharwal1 torch.einsum (effectively a matmul) for QK^T is quadratic in memory, i.e., it needs O(S^2) memory, where S is the sequence length (here H*W).
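To put numbers on that, a back-of-the-envelope sketch (`attn_matrix_bytes` is a hypothetical helper and the sizes are made-up examples): in the convolutional attention from the question, S is the number of pixels H*W, so the score tensor grows with the fourth power of the feature-map side length.

```python
def attn_matrix_bytes(batch, heads, height, width, bytes_per_elem=4):
    # The QK^T score tensor has shape (B, heads, S, S) with S = H * W pixels.
    seq_len = height * width
    return batch * heads * seq_len * seq_len * bytes_per_elem

# Example: batch 8, 8 heads, 64x64 feature map, fp32 (4 bytes per element).
gib = attn_matrix_bytes(batch=8, heads=8, height=64, width=64) / 2**30
print(f"{gib:.1f} GiB")  # 4.0 GiB
```

That is a single fp32 score tensor. The naive forward pass creates more than one tensor of this size (for example the scaled scores and the softmax output, which autograd keeps for backward), so the real peak is a multiple of it.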

It's also quite slow because of the many loads and stores that touch global memory on the GPU.
Don't worry if you don't know much about GPU anatomy.

TL;DR: a naive implementation of attention uses O(S^2) memory and O(S^2) time, while something more efficient like FlashAttention uses O(S) memory and O(S^2) time.
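Why O(S) extra memory is enough, assuming the standard "online softmax" formulation that FlashAttention builds on: for a single query, split the keys into blocks $j = 1, \dots, T$ with scores $s^{(j)}_k$ and values $v^{(j)}_k$, and keep only a running max $m$, a running denominator $\ell$ and a running unnormalized output $o$, starting from $m^{(0)} = -\infty$, $\ell^{(0)} = 0$, $o^{(0)} = 0$:

$$
\begin{aligned}
m^{(j)} &= \max\!\Big(m^{(j-1)},\ \max_k s^{(j)}_k\Big) \\
\ell^{(j)} &= e^{m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \sum_k e^{s^{(j)}_k - m^{(j)}} \\
o^{(j)} &= e^{m^{(j-1)} - m^{(j)}}\,o^{(j-1)} + \sum_k e^{s^{(j)}_k - m^{(j)}}\,v^{(j)}_k \\
\text{output} &= o^{(T)} / \ell^{(T)}
\end{aligned}
$$

After the last block, $\ell^{(T)} = \sum_k e^{s_k - m^{(T)}}$ and $o^{(T)} = \sum_k e^{s_k - m^{(T)}} v_k$ over all keys, so the ratio is exactly the softmax-weighted sum of the values, while only one block of scores is ever held at a time.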

Algorithmically, the time is quadratic for both, but FlashAttention is much faster because it reduces the loads/stores to and from GPU global memory.
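The core of the FlashAttention algorithm, tiling over key blocks with an online softmax, can be sketched in plain PyTorch as below (`tiled_attention` is a hypothetical name; this illustrates the algorithm and is not PyTorch's kernel). It is exact and never builds the full `(N, S)` score matrix, but it will not be faster: the real speedup comes from running this loop inside one fused CUDA kernel that keeps the tiles in on-chip SRAM, whereas each line here launches separate kernels that read and write global memory.

```python
import torch
import torch.nn.functional as F

def tiled_attention(q, k, v, block_size=128):
    # q: (..., N, d); k, v: (..., S, d).
    scale = q.shape[-1] ** -0.5
    m = torch.full_like(q[..., :1], float("-inf"))  # running row-wise max of the scores
    l = torch.zeros_like(q[..., :1])                # running softmax denominator
    o = torch.zeros_like(q)                         # running unnormalized output
    for start in range(0, k.shape[-2], block_size):
        k_blk = k[..., start:start + block_size, :]
        v_blk = v[..., start:start + block_size, :]
        s = q @ k_blk.transpose(-2, -1) * scale     # only an (N, block_size) tile of scores
        m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
        correction = torch.exp(m - m_new)           # rescales sums computed with the old max
        p = torch.exp(s - m_new)
        l = correction * l + p.sum(dim=-1, keepdim=True)
        o = correction * o + p @ v_blk
        m = m_new
    return o / l

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 500, 16) for _ in range(3))
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(tiled_attention(q, k, v), ref, atol=1e-5))  # True
```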

Reference paper: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (arXiv:2205.14135)
