I’ve been working on an MAE model and I’m struggling with the attention scores.
I checked the trained model and looked at the attention scores, but all the values were the same (1 / token_num).
Does anyone know what is happening here?
here is the code:
class MultiHeadAttention(nn.Module):
    """Standard multi-head self-attention.

    Projects the input into per-head queries, keys and values, applies
    scaled dot-product attention over the token axis, and merges the heads
    back into the model dimension.

    forward(x) -> (output, attention_weight)
        x:                (batch, tokens, dim)
        output:           (batch, tokens, dim)
        attention_weight: (batch, heads, tokens, tokens), rows sum to 1.
    """

    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.dim_heads = dim // n_heads
        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch, tokens, _ = x.shape

        def to_heads(t):
            # (b, n, h*d) -> (b, h, n, d); same layout as the original
            # einops Rearrange("b n (h d) -> b h n d").
            return t.view(batch, tokens, self.n_heads, self.dim_heads).transpose(1, 2)

        queries = to_heads(self.W_q(x))
        keys = to_heads(self.W_k(x))
        values = to_heads(self.W_v(x))

        # Scaled dot-product attention: softmax(q k^T / sqrt(d)).
        scores = queries @ keys.transpose(-1, -2) * (self.dim_heads ** -0.5)
        attention_weight = self.softmax(scores)

        # Weighted sum of values, then merge heads: (b, h, n, d) -> (b, n, h*d),
        # equivalent to the original Rearrange("b h n d -> b n (h d)").
        output = (attention_weight @ values).transpose(1, 2).reshape(batch, tokens, -1)
        return output, attention_weight
class Multimodal_Transformer_encoder(nn.Module):
    """Pre-norm Transformer encoder block: LN -> self-attention -> residual,
    then LN -> MLP -> residual.

    forward(x) -> (x, attn_weights) where x keeps its input shape
    (batch, tokens, Multimodal_dim) and attn_weights are the weights
    returned by the MultiHeadAttention sub-layer.
    """

    def __init__(self, Multimodal_dim, Multimodal_head, Multimodal_hidden_dimension):
        super(Multimodal_Transformer_encoder, self).__init__()
        self.Multimodal_head = Multimodal_head
        self.Multimodal_hidden_dimension = Multimodal_hidden_dimension
        self.Multimodal_dim = Multimodal_dim
        self.MLSA_norm = nn.LayerNorm(self.Multimodal_dim)
        self.MLSA = MultiHeadAttention(self.Multimodal_dim, self.Multimodal_head)
        # NOTE(review): this MLP begins with its own LayerNorm *and* forward()
        # applies self.MLP_norm immediately before it, so the MLP input is
        # normalized twice. Probably only one was intended — confirm. (Left
        # as-is here so trained checkpoints keep their parameter layout.)
        self.MLP = nn.Sequential(
            nn.LayerNorm(self.Multimodal_dim),
            nn.Linear(self.Multimodal_dim, self.Multimodal_hidden_dimension),
            nn.GELU(),
            nn.Linear(self.Multimodal_hidden_dimension, self.Multimodal_dim),
        )
        self.MLP_norm = nn.LayerNorm(self.Multimodal_dim)

    def forward(self, x):
        # Fix: use locals instead of the original self.output_MLSA /
        # self.attn_weights / self.MLP_output attributes. Storing per-forward
        # activations on the module keeps the autograd graph alive between
        # steps (memory growth) and is unsafe under DataParallel/threading.
        attn_out, attn_weights = self.MLSA(self.MLSA_norm(x))
        x = x + attn_out
        x = x + self.MLP(self.MLP_norm(x))
        return x, attn_weights
class M3AE_encoder(nn.Module):
    """Stack of Multimodal_Transformer_encoder layers forming the M3AE encoder.

    forward(x) -> (x, atten_list)
        x:          (batch, tokens, encoder_hidden_size), transformed in place
                    through every layer.
        atten_list: one attention-weight tensor per layer.
    """

    def __init__(self, Model_params):
        super(M3AE_encoder, self).__init__()
        self.Model_params = Model_params
        # Build the stack in one ModuleList construction instead of the
        # original `self.encoder_list += [...]` loop.
        self.encoder_list = nn.ModuleList(
            Multimodal_Transformer_encoder(
                Model_params["encoder_hidden_size"],
                Model_params["encoder_heads"],
                Model_params["encoder_MLP_size"],
            )
            for _ in range(Model_params["encoder_Layers"])
        )

    def forward(self, x):
        # Iterate the ModuleList directly so the loop can never drift from
        # the actual number of layers (the original re-read
        # Model_params["encoder_Layers"] and indexed by position).
        atten_list = []
        for layer in self.encoder_list:
            x, atten_weights = layer(x)
            atten_list.append(atten_weights)
        return x, atten_list
class M3AE_decoder(nn.Module):
    """Stack of Multimodal_Transformer_encoder layers used as the M3AE decoder.

    forward(x) -> (x, atten_list)
        x:          (batch, tokens, decoder_hidden_size).
        atten_list: one attention-weight tensor per layer.
    """

    def __init__(self, Model_params):
        super(M3AE_decoder, self).__init__()
        self.Model_params = Model_params
        # Build the stack in one ModuleList construction instead of the
        # original `self.decoder_list += [...]` loop (consistent with the
        # encoder class).
        self.decoder_list = nn.ModuleList(
            Multimodal_Transformer_encoder(
                Model_params["decoder_hidden_size"],
                Model_params["decoder_heads"],
                Model_params["decoder_MLP_size"],
            )
            for _ in range(Model_params["decoder_Layers"])
        )

    def forward(self, x):
        # Iterate the ModuleList directly rather than indexing by a count
        # re-read from Model_params.
        atten_list = []
        for layer in self.decoder_list:
            x, atten_weights = layer(x)
            atten_list.append(atten_weights)
        return x, atten_list
class M3AE(nn.Module):
    """Multimodal masked autoencoder (M3AE).

    Embeds image patches and scalar language tokens into a shared sequence,
    randomly masks tokens during training, encodes the visible tokens, and
    reconstructs the full sequence with a decoder.

    forward(xt, xv, Mode):
        xt: (batch, image_patches_num, image_patches_dim * image_num_per_patch)
        xv: (batch, language_number) — each scalar is lifted to the encoder
            width via Linear(1, hidden).
        Mode == "Train":
            returns (latent, [image_output, language_output], kept_indices,
                     [encoder_atten, decoder_atten])
        otherwise:
            returns (latent, encoder_atten)
    """

    def __init__(self, Model_params):
        super(M3AE, self).__init__()
        self.Model_params = Model_params
        self.Total_patches = Model_params["image_patches_num"] + Model_params["language_number"]
        self.encoder_hidden_size = Model_params["encoder_hidden_size"]
        self.dimension_per_patch = Model_params["image_patches_dim"] * Model_params["image_num_per_patch"]
        # --------------------------- encoder structure ---------------------------
        self.image_initial_projection = nn.Linear(self.dimension_per_patch, Model_params["encoder_hidden_size"])
        self.language_initial_projection = nn.Linear(1, Model_params["encoder_hidden_size"])
        self.M3AE_encoder = M3AE_encoder(Model_params)
        # Fix: register the (non-trainable) positional tables as buffers instead
        # of plain attributes with a hard-coded .to(device=0). Buffers follow
        # model.to(...)/model.cuda() automatically, so the model also runs on
        # CPU or on a GPU other than cuda:0. persistent=False keeps them out of
        # the state_dict, matching the original checkpoints. (The original also
        # built the encoder table twice — once at L86 and again at L90.)
        self.register_buffer(
            "sinual_positional_encoding",
            self.positional_encoding(Model_params["encoder_hidden_size"]),
            persistent=False,
        )
        self.mask_ratio = Model_params["mask_ratio"]
        self.encoder_image_modality = nn.Parameter(torch.randn(1, 1, Model_params["encoder_hidden_size"]))
        self.encoder_language_modality = nn.Parameter(torch.randn(1, 1, Model_params["encoder_hidden_size"]))
        # --------------------------- decoder structure ---------------------------
        self.M3AE_decoder = M3AE_decoder(Model_params)
        self.register_buffer(
            "decoder_sinual_positional_encoding",
            self.positional_encoding(Model_params["decoder_hidden_size"]),
            persistent=False,
        )
        self.encoder_decoder_projection = nn.Linear(Model_params["encoder_hidden_size"], Model_params["decoder_hidden_size"])
        self.decoder_image_modality = nn.Parameter(torch.randn(1, 1, Model_params["decoder_hidden_size"]))
        self.decoder_language_modality = nn.Parameter(torch.randn(1, 1, Model_params["decoder_hidden_size"]))
        self.masked_token = nn.Parameter(torch.randn(1, 1, Model_params["decoder_hidden_size"]))
        # ---------------------------- output structure ---------------------------
        self.image_output_linear = nn.Sequential(
            nn.Linear(Model_params["decoder_hidden_size"], Model_params["image_num_per_patch"] * Model_params["image_patches_dim"])
        )
        self.language_output_linear = nn.Sequential(
            nn.Linear(Model_params["decoder_hidden_size"], 1)
        )

    def forward(self, xt, xv, Mode="Train"):
        # ------------------------------ embedding ------------------------------
        xv = xv.unsqueeze(2)  # (batch, token_num, 1) so Linear(1, hidden) applies per scalar
        embeded_t = self.image_initial_projection(xt) + self.encoder_image_modality
        embeded_v = self.language_initial_projection(xv) + self.encoder_language_modality
        x_concat = torch.cat([embeded_t, embeded_v], dim=1)
        x_concat = x_concat + self.sinual_positional_encoding
        # ----------------------- random masking for encoder --------------------
        if Mode == "Train":
            # mask_list actually holds the indices of the *kept* (visible) tokens.
            mask_list = self.mask_index_random()
            x_concat = x_concat[:, mask_list, :]
        # ------------------------------- encoder -------------------------------
        latent_data, encoder_atten = self.M3AE_encoder(x_concat)
        if Mode == "Train":
            # Projection moved inside the Train branch: the original computed it
            # in eval mode too and then discarded the result.
            decoder_input = self.encoder_decoder_projection(latent_data)
            # Fix: map original position -> position inside the kept subsequence
            # once, instead of the original O(n^2) mask_list.index() per token.
            kept_position = {patch: pos for pos, patch in enumerate(mask_list)}
            decoder_token_list = []
            for patch_num in range(self.Total_patches):
                if patch_num in kept_position:
                    decoder_token_list.append(decoder_input[:, kept_position[patch_num], :].unsqueeze(1))
                else:
                    # Masked positions are filled with the shared learnable token.
                    decoder_token_list.append(self.masked_token.expand(decoder_input.shape[0], -1, -1))
            decoder_input = torch.cat(decoder_token_list, dim=1)
            image_patches = self.Model_params["image_patches_num"]
            language_tokens = self.Model_params["language_number"]
            decoder_input[:, :image_patches, :] = decoder_input[:, :image_patches, :] + self.decoder_image_modality
            decoder_input[:, -language_tokens:, :] = decoder_input[:, -language_tokens:, :] + self.decoder_language_modality
            decoder_input = decoder_input + self.decoder_sinual_positional_encoding
            decoder_output, decoder_atten = self.M3AE_decoder(decoder_input)
            image_output = self.image_output_linear(decoder_output[:, :image_patches, :])
            language_output = self.language_output_linear(decoder_output[:, -language_tokens:, :])
            return latent_data, [image_output, language_output], mask_list, [encoder_atten, decoder_atten]
        else:
            # Inference path: only the encoder runs.
            return latent_data, encoder_atten

    def get_angles(self, pos, i, hidden_dimension):
        """Per-(position, channel) angles of the sinusoidal encoding."""
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(hidden_dimension))
        return pos * angle_rates

    def positional_encoding(self, hidden_dimension):
        """Sinusoidal positional table of shape (1, Total_patches, hidden_dimension).

        Returned on CPU; __init__ registers it as a buffer so it moves with
        the module (the original hard-coded .to(device=0), which broke
        CPU-only and non-cuda:0 setups).
        """
        angle_rads = self.get_angles(
            np.arange(self.Total_patches)[:, np.newaxis],
            np.arange(hidden_dimension)[np.newaxis, :],
            hidden_dimension,
        )
        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])  # even channels: sin
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])  # odd channels: cos
        pos_encoding = angle_rads[np.newaxis, :, :]
        return torch.tensor(pos_encoding).float()

    def mask_index_random(self):
        """Independently keep each token with probability (1 - mask_ratio).

        Returns the sorted indices of the kept tokens. NOTE(review): unlike
        the usual MAE shuffle-and-slice, the number of kept tokens varies per
        call and can even be zero when mask_ratio is high — confirm this is
        intended.
        """
        return [index for index in range(self.Total_patches) if random.random() > self.mask_ratio]
I also used the “nn.MultiheadAttention” module, but the same thing happened.