A word's embedding from nn.TransformerEncoder is the same in different sentences

Assume the sentence “src” has 4 words, and each word is embedded into a 16-dim vector:
src = torch.randn(1, 4, 16)
The first word stays fixed, and the following three words are randomly regenerated. As you can see in the output of nn.TransformerEncoder, built with encode_layer = nn.TransformerEncoderLayer(d_model=16, nhead=1, dropout=0) and transformer_encoder = nn.TransformerEncoder(encode_layer, num_layers=1), no matter how the following words change, the embedding of the first word stays the same.
I think this result is unreasonable, or is there a problem with the way I am using it?

>>> import torch
>>> import torch.nn as nn
>>> src = torch.randn(1,4,16)
>>> src
tensor([[[-1.2324,  1.0855, -0.1256, -1.0511, -0.9130, -0.9795,  2.1580,
           0.7377,  0.9881, -0.5899, -0.7326, -0.3581,  0.1045,  0.0146,
           0.9395, -1.8705],
         [-0.2338, -1.1813,  0.6624,  0.4502, -0.8712,  1.5969, -0.1144,
           1.6774,  2.0996, -0.8762,  0.3418,  0.8612,  0.2751,  0.7086,
           0.1442, -0.1141],
         [-1.6453,  0.3087, -0.7955, -0.4064, -0.4819,  0.0947, -1.3674,
          -1.4710,  1.5510, -0.4081,  0.4080,  1.5987,  0.6879,  1.0894,
           0.1054, -0.3860],
         [ 1.4328, -0.3503, -0.8012, -0.2605, -0.2735,  0.4599, -1.1251,
           0.9143, -0.5636, -0.3161, -1.9138, -0.4677,  0.9299,  0.0075,
           0.9642,  0.6295]]])
>>> t = src[0][0]
>>> encode_layer = nn.TransformerEncoderLayer(d_model=16, nhead=1, dropout=0)
>>> transformer_encoder = nn.TransformerEncoder(encode_layer, num_layers=1)
>>> out = transformer_encoder(src)
>>> out
tensor([[[-1.6914,  0.9598, -0.3110, -0.9993, -0.9691,  0.0186,  1.6132,
           0.5669,  1.3026,  0.0505,  0.0253,  0.0407,  0.8174, -0.0248,
           0.6984, -2.0979],
         [-1.0822, -1.7807,  0.3079,  0.3645, -0.7527,  1.7316, -0.4070,
           1.4846,  1.6255, -1.6291,  0.2163,  0.1240, -0.0125, -0.1307,
           0.2133, -0.2728],
         [-1.7741,  0.4637, -0.5941, -0.5148, -0.0103, -0.2896, -1.6066,
          -0.8462,  2.1860, -0.4621,  0.8168,  1.2694,  0.6117,  0.5622,
           0.6685, -0.4805],
         [ 1.0945, -0.3281, -1.3840, -0.0261, -0.0547,  0.9100, -1.4425,
           1.1332, -1.0016, -0.9842, -1.8533,  0.8731,  0.8267,  0.5619,
           0.7640,  0.9113]]], grad_fn=<NativeLayerNormBackward>)
>>> src = torch.randn(1,4,16)
>>> src[0][0] = t
>>> src
tensor([[[-1.2324,  1.0855, -0.1256, -1.0511, -0.9130, -0.9795,  2.1580,
           0.7377,  0.9881, -0.5899, -0.7326, -0.3581,  0.1045,  0.0146,
           0.9395, -1.8705],
         [-0.0318, -0.9081,  0.0209, -0.0474, -0.6776,  1.0178,  0.1818,
           0.3425, -0.3291,  0.5315, -1.1100, -0.0805, -2.8687, -0.1846,
           0.4410, -1.3325],
         [-0.3382,  0.1080, -1.1487,  0.9970, -1.3720, -0.7435, -0.6814,
          -0.3464,  1.2251,  1.0221,  0.2239, -1.7205,  0.9573, -0.1409,
          -0.2924,  0.0602],
         [-1.5206,  0.3053,  1.0292,  0.4906, -1.1222,  0.7273, -1.3435,
           0.3213, -0.6371,  1.2545,  1.5598, -0.2372,  0.0758,  0.4457,
           0.9007, -0.2775]]])
>>> out = transformer_encoder(src)
>>> out
tensor([[[-1.6914,  0.9598, -0.3110, -0.9993, -0.9691,  0.0186,  1.6132,
           0.5669,  1.3026,  0.0505,  0.0253,  0.0407,  0.8174, -0.0248,
           0.6984, -2.0979],
         [-0.1582, -0.9467,  0.6978,  0.8175, -0.8259,  2.1026, -0.3020,
           0.4182, -0.3150,  0.2300,  0.1154,  0.3126, -2.4654,  0.5631,
           0.8439, -1.0878],
         [-1.0309,  0.3078, -2.1636,  0.8545, -1.1014, -0.0249, -0.0669,
          -0.1179,  1.2710,  1.6001, -0.4917, -0.7905,  1.4325,  0.4466,
          -0.7322,  0.6074],
         [-1.8881, -0.2105,  0.9825,  0.4861, -0.8173,  0.9925, -1.7599,
           0.3325, -1.2461,  0.2011,  1.7443, -0.2333,  0.5631,  0.7407,
           0.6749, -0.5626]]], grad_fn=<NativeLayerNormBackward>)

Can anyone help me? I think a word’s embedding should depend on the other words in the sentence in a transformer.

Could you please post your code by wrapping it into three backticks ``` and explain your use case? :slight_smile:
In particular, what is the code doing, what are you expecting, and what goes wrong.

Hi, I have changed the description, could you help me?

Thanks for the code update.
Could you explain why the first output should be different if you pass in the same tensor?
If I’m not mistaken, the submodules are multi-head attention as well as some linear layers, and I’m not sure why they should yield a different output. Note, however, that I haven’t looked into the implementation deeply.

Not the same tensor, but the first element of the tensor is the same. It’s like the first word of two sentences being the same, for example, “I have an apple” and “I am not sure”. In my understanding of the transformer, the embedding of the same word in different sentences should be different, because with self-attention the embedding of a word is the weighted sum of the values of all words in the sentence. But in my example, the first word ([-1.2324, 1.0855, -0.1256, -1.0511, -0.9130, -0.9795, 2.1580, 0.7377, 0.9881, -0.5899, -0.7326, -0.3581, 0.1045, 0.0146, 0.9395, -1.8705]) got the same embedding both times: [-1.6914, 0.9598, -0.3110, -0.9993, -0.9691, 0.0186, 1.6132, 0.5669, 1.3026, 0.0505, 0.0253, 0.0407, 0.8174, -0.0248, 0.6984, -2.0979]
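The weighted-sum view of self-attention described above can be sketched as follows. This is a minimal, hypothetical single-head attention with identity query/key/value projections (real layers learn these projections); it just shows that each output row mixes all input rows, so changing one word changes every word's output:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 16
x = torch.randn(4, d)  # 4 words, each a 16-dim vector

# Scaled dot-product attention with identity projections (for clarity only)
scores = x @ x.T / d ** 0.5          # (4, 4) pairwise similarity scores
weights = F.softmax(scores, dim=-1)  # each row sums to 1
out = weights @ x                    # each output is a weighted sum of all words

# Change only word 3; word 0's output changes too, because word 0 attends to it
x2 = x.clone()
x2[3] = torch.randn(d)
scores2 = x2 @ x2.T / d ** 0.5
out2 = F.softmax(scores2, dim=-1) @ x2

print(torch.allclose(out[0], out2[0]))  # expect False
```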

My usage was wrong: I should use src = torch.randn(4, 1, 16), because nn.TransformerEncoder’s default input shape is (seq_len, batch_size, embedding_dim). With shape (1, 4, 16), the four words were treated as a batch of four one-word sequences, so the first word had nothing to attend to.
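A quick sketch to confirm the fix: with the (seq_len, batch_size, embedding_dim) layout, repeating the same first word while changing the other three words should now give the first word a different output, since self-attention mixes in the changed words:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

encoder_layer = nn.TransformerEncoderLayer(d_model=16, nhead=1, dropout=0)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
encoder.eval()  # deterministic forward passes

# Correct layout: (seq_len=4, batch_size=1, embedding_dim=16)
src = torch.randn(4, 1, 16)
first_word = src[0].clone()

with torch.no_grad():
    out1 = encoder(src)

# A second "sentence": same first word, different following words
src2 = torch.randn(4, 1, 16)
src2[0] = first_word
with torch.no_grad():
    out2 = encoder(src2)

# The first word's embedding now differs between the two sentences
print(torch.allclose(out1[0], out2[0]))  # expect False
```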