I was trying to implement the attention computation from scratch, so I used torch.einsum for the matrix multiplication between the query and key tensors. This step always threw CUDA OOM errors, but when I switched to F.scaled_dot_product_attention, the model ran fine and didn't throw any OOM errors at all. Can anyone help me understand why torch.einsum is such a GPU memory-intensive operation here, and why the torch implementation of scaled_dot_product_attention uses so much less GPU memory?
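To put rough numbers on what I suspect is happening: my scratch version materializes the full attention matrix of shape (B, heads, Hq*Wq, Hkv*Wkv) before the softmax (see the comments in the code below). With illustrative sizes (made up for this post, not my actual config), that one intermediate is already enormous:

import torch  # only needed if you want to allocate it for real

# Size of the materialized attention matrix in the scratch version.
# All dimensions below are placeholders for illustration.
B, heads = 8, 8
H = W = 128                       # feature-map height/width
hw = H * W                        # 16384 spatial tokens per head
attn_elems = B * heads * hw * hw  # elements in (B, heads, H*W, H*W)
attn_bytes = attn_elems * 4       # float32
print(f"attention matrix alone: {attn_bytes / 1024**3:.1f} GiB")  # ~64 GiB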
Code for my scratch implementation, using torch.einsum:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvMHAttn_Scratch(nn.Module):
    """
    Convolutional Multi-Head Attention Network
    """
    def __init__(self, in_channels, out_channels, num_heads):
        super(ConvMHAttn_Scratch, self).__init__()
        assert out_channels % num_heads == 0, "out_channels must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = out_channels // num_heads  # Each head's dimensionality
        self.q_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)    # Query projection
        self.k_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)    # Key projection
        self.v_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)    # Value projection
        self.out_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1) # Output projection

    def forward(self, query, key, value):
        assert query.shape == key.shape, "Query and Key must have the same shape"
        assert key.shape == value.shape, "Key and Value must have the same shape"
        B, Cq, Hq, Wq = query.shape   # Query shape (target input)
        B, Ckv, Hkv, Wkv = key.shape  # Key-Value shape (source input)

        # Step 1: Create Query, Key, and Value
        Q = self.q_conv(query).view(B, self.num_heads, self.head_dim, Hq * Wq)   # (B, heads, C_head, Hq*Wq)
        K = self.k_conv(key).view(B, self.num_heads, self.head_dim, Hkv * Wkv)   # (B, heads, C_head, Hkv*Wkv)
        V = self.v_conv(value).view(B, self.num_heads, self.head_dim, Hkv * Wkv) # (B, heads, C_head, Hkv*Wkv)

        # Step 2: Scaled dot-product attention (the full attention matrix is materialized here)
        print("Q shape: ", Q.shape)
        print("K shape: ", K.shape)
        print("V shape: ", V.shape)
        attn = torch.einsum('bhci,bhcj->bhij', Q, K) / (self.head_dim ** 0.5)  # (B, heads, Hq*Wq, Hkv*Wkv)
        attn = torch.softmax(attn, dim=-1)                                     # (B, heads, Hq*Wq, Hkv*Wkv)

        # Step 3: Apply attention to V
        out = torch.einsum('bhij,bhcj->bhci', attn, V)  # (B, heads, C_head, Hq*Wq)
        print("Out Shape 1", out.shape)
        out = out.reshape(B, -1, Hq, Wq)                # (B, C, H, W), combining heads
        print("Out Shape 2", out.shape)

        # Step 4: Final projection (combining heads)
        out = self.out_conv(out)
        return out
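For what it's worth, I don't think the einsum notation itself is the culprit; as far as I can tell, the two einsum calls above are just batched matmuls on transposed tensors. A quick sanity-check sketch with small, made-up shapes:

import torch

# Small, made-up shapes just to check the einsum <-> matmul equivalence.
B, heads, c, i, j = 2, 4, 8, 16, 16
Q = torch.randn(B, heads, c, i)
K = torch.randn(B, heads, c, j)
V = torch.randn(B, heads, c, j)

attn_einsum = torch.einsum('bhci,bhcj->bhij', Q, K)
attn_matmul = Q.transpose(-2, -1) @ K                        # (B, heads, i, j)
print(torch.allclose(attn_einsum, attn_matmul, atol=1e-5))   # True

out_einsum = torch.einsum('bhij,bhcj->bhci', attn_einsum, V)
out_matmul = V @ attn_einsum.transpose(-2, -1)               # (B, heads, c, i)
print(torch.allclose(out_einsum, out_matmul, atol=1e-5))     # True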
Using F.scaled_dot_product_attention:
class ConvMHAttn(nn.Module):
    """
    Convolutional Multi-Head Attention Network
    """
    def __init__(self, in_channels, out_channels, num_heads):
        super(ConvMHAttn, self).__init__()
        assert out_channels % num_heads == 0, "out_channels must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = out_channels // num_heads  # Each head's dimensionality
        self.q_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)    # Query projection
        self.k_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)    # Key projection
        self.v_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)    # Value projection
        self.out_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1) # Output projection

    def forward(self, query, key, value):
        assert query.shape == key.shape, "Query and Key must have the same shape"
        assert key.shape == value.shape, "Key and Value must have the same shape"
        B, Cq, Hq, Wq = query.shape   # Query shape (target input)
        B, Ckv, Hkv, Wkv = key.shape  # Key-Value shape (source input)

        # Step 1: Create Query, Key, and Value
        Q = self.q_conv(query).view(B, self.num_heads, self.head_dim, Hq * Wq)   # (B, heads, C_head, Hq*Wq)
        K = self.k_conv(key).view(B, self.num_heads, self.head_dim, Hkv * Wkv)   # (B, heads, C_head, Hkv*Wkv)
        V = self.v_conv(value).view(B, self.num_heads, self.head_dim, Hkv * Wkv) # (B, heads, C_head, Hkv*Wkv)

        # Step 2: Scaled dot-product attention (no attention matrix is returned to the caller)
        out = F.scaled_dot_product_attention(query=Q, key=K, value=V, dropout_p=0.2)

        # Step 3: Combine heads back into the channel dimension
        out = out.reshape(B, -1, Hq, Wq)  # (B, C, H, W)

        # Step 4: Final projection
        out = self.out_conv(out)
        return out
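In case it helps anyone reproduce the difference, this is the kind of minimal peak-memory comparison I have in mind; torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() are standard PyTorch APIs, but the input sizes below are placeholders (small enough that both variants fit), not my actual config:

import torch

def peak_memory_mb(module, q, k, v):
    """Run one forward pass and report peak allocated CUDA memory in MB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        module(q, k, v)
    return torch.cuda.max_memory_allocated() / 1024**2

device = "cuda"
B, C, H, W = 2, 64, 64, 64  # placeholder sizes for illustration
x = torch.randn(B, C, H, W, device=device)

scratch = ConvMHAttn_Scratch(C, C, num_heads=8).to(device)
builtin = ConvMHAttn(C, C, num_heads=8).to(device)

print("scratch einsum :", peak_memory_mb(scratch, x, x, x), "MB")
print("built-in SDPA  :", peak_memory_mb(builtin, x, x, x), "MB")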