My Transformer encoder attention scores are all the same value

I’ve been working on an MAE model and I’m struggling with the attention scores.
I checked the trained model and looked at the attention scores, but all the values were the same (1 / token_num).

Does anyone know what is happening here?

here is the code:

class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head self-attention (no output projection).

    Input and output are (batch, tokens, dim); also returns the softmaxed
    attention weights of shape (batch, heads, tokens, tokens).
    """

    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        # Per-head feature size; assumes dim is divisible by n_heads.
        self.dim_heads = dim // n_heads

        # Independent linear projections for queries, keys and values.
        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)

        # einops layers: (b, n, h*d) <-> (b, h, n, d)
        self.split_into_heads = Rearrange("b n (h d) -> b h n d", h=self.n_heads)
        self.softmax = nn.Softmax(dim=-1)
        self.concat = Rearrange("b h n d -> b n (h d)", h=self.n_heads)

    def forward(self, x):
        """Return (attended_output, attention_weight) for x of shape (b, n, dim)."""
        queries = self.split_into_heads(self.W_q(x))
        keys = self.split_into_heads(self.W_k(x))
        values = self.split_into_heads(self.W_v(x))

        # Scale logits by 1/sqrt(head_dim) before the row-wise softmax.
        scale = self.dim_heads ** -0.5
        scores = torch.matmul(queries, keys.transpose(-1, -2)) * scale
        attention_weight = self.softmax(scores)

        context = torch.matmul(attention_weight, values)
        return self.concat(context), attention_weight


class Multimodal_Transformer_encoder(nn.Module):
    """One pre-norm Transformer block: LN -> MHSA -> residual, LN -> MLP -> residual.

    Returns both the transformed sequence and the attention weights of its
    self-attention layer.
    """

    def __init__(self, Multimodal_dim, Multimodal_head, Multimodal_hidden_dimension):
        super(Multimodal_Transformer_encoder, self).__init__()

        self.Multimodal_head = Multimodal_head
        self.Multimodal_hidden_dimension = Multimodal_hidden_dimension
        self.Multimodal_dim = Multimodal_dim

        self.MLSA_norm = nn.LayerNorm(self.Multimodal_dim)
        self.MLSA = MultiHeadAttention(self.Multimodal_dim, self.Multimodal_head)

        # NOTE(review): this MLP starts with its own LayerNorm, and forward()
        # also applies MLP_norm first — the MLP input is normalized twice.
        # Left as-is to keep trained checkpoints loadable; consider dropping one.
        self.MLP = nn.Sequential(
            nn.LayerNorm(self.Multimodal_dim),
            nn.Linear(self.Multimodal_dim, self.Multimodal_hidden_dimension),
            nn.GELU(),
            nn.Linear(self.Multimodal_hidden_dimension, self.Multimodal_dim),
        )

        self.MLP_norm = nn.LayerNorm(self.Multimodal_dim)

    def forward(self, x):
        """Apply attention and MLP sub-blocks with residual connections.

        Args:
            x: (batch, tokens, Multimodal_dim) input sequence.

        Returns:
            (x, attn_weights): transformed sequence and attention weights.
        """
        # Bug fix: intermediates were previously stored on self
        # (self.output_MLSA, self.attn_weights, self.MLP_output), which keeps
        # the autograd graph alive across iterations and breaks
        # nn.DataParallel replication; plain locals are used instead.
        attn_out, attn_weights = self.MLSA(self.MLSA_norm(x))
        x = x + attn_out

        x = x + self.MLP(self.MLP_norm(x))

        return x, attn_weights


class M3AE_encoder(nn.Module):
    """Stack of Transformer encoder blocks for the M3AE encoder side.

    Returns the final hidden states along with the attention weights of
    every layer.
    """

    def __init__(self, Model_params):
        super(M3AE_encoder, self).__init__()
        self.Model_params = Model_params

        # One encoder block per configured layer.
        self.encoder_list = nn.ModuleList(
            Multimodal_Transformer_encoder(
                Model_params["encoder_hidden_size"],
                Model_params["encoder_heads"],
                Model_params["encoder_MLP_size"],
            )
            for _ in range(Model_params["encoder_Layers"])
        )

    def forward(self, x):
        """Run x (batch, tokens, hidden) through all layers; collect attention."""
        atten_list = []
        for block in self.encoder_list:
            x, atten_weights = block(x)
            atten_list.append(atten_weights)

        return x, atten_list


class M3AE_decoder(nn.Module):
    """Stack of Transformer blocks used as the M3AE decoder.

    Reuses the encoder block type; returns final hidden states and per-layer
    attention weights.
    """

    def __init__(self, Model_params):
        super(M3AE_decoder, self).__init__()
        self.Model_params = Model_params

        self.decoder_list = nn.ModuleList(
            Multimodal_Transformer_encoder(
                Model_params["decoder_hidden_size"],
                Model_params["decoder_heads"],
                Model_params["decoder_MLP_size"],
            )
            for _ in range(Model_params["decoder_Layers"])
        )

    def forward(self, x):
        """Run x through every decoder layer, collecting attention weights."""
        atten_list = []
        for block in self.decoder_list:
            x, atten_weights = block(x)
            atten_list.append(atten_weights)

        return x, atten_list

class M3AE(nn.Module):
    """Multimodal masked autoencoder over image patches plus language tokens.

    Embeds both modalities, randomly masks tokens, encodes the visible ones,
    re-expands with a learned mask token, and decodes to per-modality
    reconstructions.
    """

    def __init__(self, Model_params):
        super(M3AE, self).__init__()
        self.Model_params = Model_params
        # Full sequence length: image patches followed by language tokens.
        self.Total_patches = Model_params["image_patches_num"] + Model_params["language_number"]

        self.encoder_hidden_size = Model_params["encoder_hidden_size"]
        self.dimension_per_patch = Model_params["image_patches_dim"] * Model_params["image_num_per_patch"]

        # --------------------------------encoder structure-------------------------------------

        self.image_initial_projection = nn.Linear(self.dimension_per_patch, Model_params["encoder_hidden_size"])
        self.language_initial_projection = nn.Linear(1, Model_params["encoder_hidden_size"])

        self.M3AE_encoder = M3AE_encoder(Model_params)

        # Bug fix: the encoder positional encoding was computed twice and sent
        # to a hard-coded CUDA device (.to(device=0)), which crashes on
        # CPU-only machines and ignores .to()/.cuda() on the module. A
        # non-persistent buffer follows the module's device and stays out of
        # the state_dict, so existing checkpoints still load.
        self.register_buffer(
            "sinual_positional_encoding",
            self.positional_encoding(Model_params["encoder_hidden_size"]),
            persistent=False,
        )

        self.mask_ratio = Model_params["mask_ratio"]

        # Learned modality embeddings, broadcast over batch and tokens.
        self.encoder_image_modality = nn.Parameter(torch.randn(1, 1, Model_params["encoder_hidden_size"]))
        self.encoder_language_modality = nn.Parameter(torch.randn(1, 1, Model_params["encoder_hidden_size"]))

        # --------------------------------decoder structure-------------------------------------

        self.M3AE_decoder = M3AE_decoder(Model_params)

        self.register_buffer(
            "decoder_sinual_positional_encoding",
            self.positional_encoding(Model_params["decoder_hidden_size"]),
            persistent=False,
        )

        self.encoder_decoder_projection = nn.Linear(Model_params["encoder_hidden_size"], Model_params["decoder_hidden_size"])

        self.decoder_image_modality = nn.Parameter(torch.randn(1, 1, Model_params["decoder_hidden_size"]))
        self.decoder_language_modality = nn.Parameter(torch.randn(1, 1, Model_params["decoder_hidden_size"]))

        # Shared learned embedding substituted at every masked-out position.
        self.masked_token = nn.Parameter(torch.randn(1, 1, Model_params["decoder_hidden_size"]))

        # ------------------------------output_structure-----------------------------------------

        self.image_output_linear = nn.Sequential(
            nn.Linear(Model_params["decoder_hidden_size"], Model_params["image_num_per_patch"] * Model_params["image_patches_dim"])
        )

        self.language_output_linear = nn.Sequential(
            nn.Linear(Model_params["decoder_hidden_size"], 1)
        )

    def forward(self, xt, xv, Mode="Train"):
        """Encode (and, in "Train" mode, mask + decode) a multimodal batch.

        Args:
            xt: image patches, (batch, patch, dimension_per_patch).
            xv: language tokens, (batch, token_num) — scalar per token.
            Mode: "Train" runs masking + decoder; anything else encodes only.

        Returns:
            Train: (latent, [image_output, language_output], mask_list,
                    [encoder_atten, decoder_atten]).
            Otherwise: (latent, encoder_atten).
        """
        # -----------embedding---------------------#
        xv = xv.unsqueeze(2)  # (batch, token_num) -> (batch, token_num, 1)
        embeded_t = self.image_initial_projection(xt) + self.encoder_image_modality
        embeded_v = self.language_initial_projection(xv) + self.encoder_language_modality
        x_concat = torch.cat([embeded_t, embeded_v], dim=1)
        x_concat = x_concat + self.sinual_positional_encoding

        # ----------random masking for encoder-----#
        if Mode == "Train":
            mask_list = self.mask_index_random()
            # Encoder only ever sees the kept (visible) tokens.
            x_concat = x_concat[:, mask_list, :]

        # ---------encoder-------------------------#
        latent_data, encoder_atten = self.M3AE_encoder(x_concat)

        decoder_input = self.encoder_decoder_projection(latent_data)

        if Mode == "Train":
            # Re-expand to the full sequence: visible positions get their
            # encoded token, masked positions get the shared mask token.
            # Perf fix: precompute token->row map instead of calling
            # mask_list.index() inside the loop (was O(n^2)).
            kept_row = {token: row for row, token in enumerate(mask_list)}
            decoder_token_list = []
            for patch_num in range(self.Total_patches):
                if patch_num in kept_row:
                    decoder_token_list.append(decoder_input[:, kept_row[patch_num], :].unsqueeze(1))
                else:
                    decoder_token_list.append(self.masked_token.expand(decoder_input.shape[0], -1, -1))

            decoder_input = torch.cat(decoder_token_list, dim=1)

            image_patches_num = self.Model_params["image_patches_num"]
            language_number = self.Model_params["language_number"]
            decoder_input[:, :image_patches_num, :] = decoder_input[:, :image_patches_num, :] + self.decoder_image_modality
            decoder_input[:, -language_number:, :] = decoder_input[:, -language_number:, :] + self.decoder_language_modality

            decoder_input = decoder_input + self.decoder_sinual_positional_encoding
            decoder_output, decoder_atten = self.M3AE_decoder(decoder_input)

            image_output = self.image_output_linear(decoder_output[:, :image_patches_num, :])
            language_output = self.language_output_linear(decoder_output[:, -language_number:, :])

            return latent_data, [image_output, language_output], mask_list, [encoder_atten, decoder_atten]

        else:
            return latent_data, encoder_atten

    def get_angles(self, pos, i, hidden_dimension):
        """Angle table from "Attention Is All You Need": pos / 10000^(2*(i//2)/d)."""
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(hidden_dimension))
        return pos * angle_rates

    def positional_encoding(self, hidden_dimension):
        """Return sinusoidal positional encodings, shape (1, Total_patches, hidden_dimension).

        Built on CPU; device placement is handled by registering the result
        as a module buffer (previously this hard-coded .to(device=0)).
        """
        angle_rads = self.get_angles(
            np.arange(self.Total_patches)[:, np.newaxis],
            np.arange(hidden_dimension)[np.newaxis, :],
            hidden_dimension,
        )

        # apply sin to even indices in the array; 2i
        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

        # apply cos to odd indices in the array; 2i+1
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

        pos_encoding = angle_rads[np.newaxis, :, :]

        return torch.tensor(pos_encoding).float()

    def mask_index_random(self):
        """Return sorted indices of tokens KEPT (each kept with prob 1 - mask_ratio).

        Bug fix: this could previously return an empty list (every token
        masked), which crashes the encoder on a zero-length sequence; now at
        least one token is always kept.
        """
        mask_list = [index for index in range(self.Total_patches) if random.random() > self.mask_ratio]
        if not mask_list:
            mask_list = [random.randrange(self.Total_patches)]
        return mask_list

I also used the “nn.MultiheadAttention” module, but the same thing happened.

I’m running into this same issue — were you able to figure it out?