I would like to implement cross-layer parameter sharing between attention blocks, similar to what's done in ALBERT. I basically want to share the same Q, K, V projection matrices across layers. Can anyone please check this implementation and confirm that it does the same?
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SharedAttention(nn.Module):
    def __init__(self, shared_Wq, shared_Wk, shared_Wv, input_dim, hidden_dim, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        # Assigning the same nn.Parameter objects to every block is what gives
        # ALBERT-style cross-layer sharing: all blocks read and update one
        # underlying set of weights
        self.shared_Wq = shared_Wq
        self.shared_Wk = shared_Wk
        self.shared_Wv = shared_Wv

    def forward(self, input):
        # Linear projections for Q, K, and V, each with its own shared matrix
        q = torch.matmul(input, self.shared_Wq)
        k = torch.matmul(input, self.shared_Wk)
        v = torch.matmul(input, self.shared_Wv)
        # Split hidden_dim into num_heads heads of size hidden_dim // num_heads
        head_dim = self.hidden_dim // self.num_heads
        q = q.view(input.size(0), -1, self.num_heads, head_dim).transpose(1, 2)
        k = k.view(input.size(0), -1, self.num_heads, head_dim).transpose(1, 2)
        v = v.view(input.size(0), -1, self.num_heads, head_dim).transpose(1, 2)
        # Scaled dot-product attention, scaled by the per-head dimension
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, v)
        # Concatenate the heads back to (batch, seq_len, hidden_dim)
        output = output.transpose(1, 2).contiguous().view(input.size(0), -1, self.hidden_dim)
        return output
class Model(nn.Module):
    def __init__(self, shared_Wq, shared_Wk, shared_Wv, input_dim, hidden_dim, output_dim, num_heads):
        super().__init__()
        # Both blocks receive the same parameter objects, so their weights are tied
        self.attention1 = SharedAttention(shared_Wq, shared_Wk, shared_Wv, input_dim, hidden_dim, num_heads)
        self.attention2 = SharedAttention(shared_Wq, shared_Wk, shared_Wv, input_dim, hidden_dim, num_heads)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input):
        output1 = self.attention1(input)
        output2 = self.attention2(output1)
        output = self.fc(output2)
        return output
# For stacked blocks to reuse one set of projections, each block's output
# dim must match its input dim, so input_dim must equal hidden_dim
input_dim = 12
hidden_dim = 12  # split across heads: each head gets hidden_dim // num_heads
output_dim = 1
num_heads = 2

# Each projection maps input_dim -> hidden_dim; the per-head split happens in forward()
shared_Wq = nn.Parameter(torch.randn(input_dim, hidden_dim))
shared_Wk = nn.Parameter(torch.randn(input_dim, hidden_dim))
shared_Wv = nn.Parameter(torch.randn(input_dim, hidden_dim))
# Instantiate the model with shared weights
model = Model(shared_Wq, shared_Wk, shared_Wv, input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, num_heads=num_heads)
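
For completeness, here is a quick sanity check I run to convince myself the sharing actually happens (the batch size, sequence length, and the names x / out below are arbitrary placeholders): the parameter objects in both blocks should be the very same Python object, and model.parameters() should count each shared matrix only once.

# Sanity check: both blocks should hold the same Parameter objects
x = torch.randn(4, 10, input_dim)  # (batch, seq_len, input_dim), arbitrary sizes
out = model(x)
print(out.shape)  # expected: torch.Size([4, 10, 1])

# Identity (not just equality) means gradients from both layers
# accumulate into one tensor
assert model.attention1.shared_Wq is model.attention2.shared_Wq
assert model.attention1.shared_Wk is model.attention2.shared_Wk
assert model.attention1.shared_Wv is model.attention2.shared_Wv

# model.parameters() deduplicates shared tensors, so each matrix counts once:
# 3 * 12 * 12 (shared Q/K/V) + 12 * 1 + 1 (fc weight and bias) = 445
print(sum(p.numel() for p in model.parameters()))

As a design note, creating single nn.Linear(input_dim, hidden_dim) instances once outside the blocks and passing them into every SharedAttention would achieve the same tying while also sharing the bias terms.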