Shared parameters

I would like to implement cross-layer parameter sharing between attention blocks, similar to what is done in ALBERT: every attention layer reuses the same Q, K, and V projection matrices. Could anyone please check this implementation and confirm whether it does the same thing?

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SharedAttention(nn.Module):
    def __init__(self, shared_Wq, shared_Wk, shared_Wv, input_dim, hidden_dim, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # The projection matrices are passed in from outside, so every
        # SharedAttention instance that receives them uses the same tensors.
        self.shared_Wq = shared_Wq
        self.shared_Wk = shared_Wk
        self.shared_Wv = shared_Wv

    def forward(self, input):
        # Linear projections for Q, K, and V (each weight is input_dim x hidden_dim)
        q = torch.matmul(input, self.shared_Wq)
        k = torch.matmul(input, self.shared_Wk)
        v = torch.matmul(input, self.shared_Wv)

        # Split into heads: (batch, seq, hidden_dim) -> (batch, heads, seq, head_dim)
        q = q.view(input.size(0), -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(input.size(0), -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(input.size(0), -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention, scaled by the per-head dimension
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, v)

        # Concatenate the heads and reshape back to (batch, seq, hidden_dim)
        output = output.transpose(1, 2).contiguous().view(input.size(0), -1, self.hidden_dim)

        return output

class Model(nn.Module):
    def __init__(self, shared_Wq, shared_Wk, shared_Wv, input_dim, hidden_dim, output_dim, num_heads):
        super().__init__()
        # Both attention blocks receive the same nn.Parameter objects,
        # so their Q/K/V projections are shared across layers.
        self.attention1 = SharedAttention(shared_Wq, shared_Wk, shared_Wv, input_dim, hidden_dim, num_heads)
        self.attention2 = SharedAttention(shared_Wq, shared_Wk, shared_Wv, input_dim, hidden_dim, num_heads)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input):
        output1 = self.attention1(input)
        output2 = self.attention2(output1)
        output = self.fc(output2)
        return output

# For the second layer to reuse the same projections, its input dimension must
# match the first layer's output, so input_dim == hidden_dim here (ALBERT keeps
# the hidden size constant across layers for the same reason).
input_dim = 12
hidden_dim = 12
output_dim = 1
num_heads = 2
shared_Wq = nn.Parameter(torch.randn(input_dim, hidden_dim))
shared_Wk = nn.Parameter(torch.randn(input_dim, hidden_dim))
shared_Wv = nn.Parameter(torch.randn(input_dim, hidden_dim))

# Instantiate the model with the shared projection weights

model = Model(shared_Wq, shared_Wk, shared_Wv, input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, num_heads=num_heads)
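
If it helps, here is a quick sanity check (continuing the script above, with an arbitrary batch size and sequence length) that both layers really hold the same Parameter objects and that a forward pass runs:

# Both attention blocks should reference the exact same Parameter objects
print(model.attention1.shared_Wq is model.attention2.shared_Wq)  # True
print(model.attention1.shared_Wk is model.attention2.shared_Wk)  # True
print(model.attention1.shared_Wv is model.attention2.shared_Wv)  # True

# model.parameters() skips duplicate tensors, so each shared projection is counted once
print(sum(p.numel() for p in model.parameters()))

# Dummy input of shape (batch_size, seq_len, input_dim)
x = torch.randn(4, 10, input_dim)
print(model(x).shape)  # expected: torch.Size([4, 10, 1])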