Embed_dim error in a Simple Transformer Model

I am creating a simplified version of the transformer model that does not use embeddings, does not use attention masks, and the inputs are simply a sequence of tensors of length 4. Code is based on Sequence-to-Sequence Modeling with nn.Transformer and TorchText — PyTorch Tutorials 1.8.1+cu102 documentation

However, I am getting an AssertionError when the torch module runs a check using assert embed_dim == embed_dim_to_check

Strangely, when I replace

input_batch = torch.randn(3, 4)  # (batch_size, seq_len)

with a version that adds a 3rd dimension, and increment d_model from 1 to 2,

input_batch = torch.randn(3, 4, 2)  # (batch_size, seq_len, n_features)

the error goes away!

However, my original data just has a sequence of numbers, with a feature size of 1, such as

[1.0, 2.0, 3.0, 4.0]

I can’t understand why this error occurs, hope someone can shed some light on it. Thank you!

import torch
import torch.nn as nn
import math


class Trans(nn.Module):
    """Encoder-only transformer without embeddings or attention masks.

    The input is passed through ``PositionalEncoding`` and then through a
    stack of ``nlayers`` identical ``nn.TransformerEncoderLayer`` modules.

    Args:
        d_model: Feature size handed to the positional encoding and to every
            encoder layer.
        nhead: Number of attention heads per encoder layer.
        nhid: Width of the feed-forward sub-layer inside each encoder layer.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout probability shared by the positional encoding and
            the encoder layers.
    """

    def __init__(self, d_model, nhead, nhid, nlayers, dropout=0.5):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=nhid,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=nlayers)
        self.d_model = d_model

    def forward(self, src):
        """Return the encoder output for ``src``.

        ``src`` is first combined with the positional encoding; no mask is
        supplied to the encoder stack.
        """
        # NOTE(review): the positional encoding indexes its table by
        # src.size(0), so axis 0 is treated as the sequence axis -- confirm
        # callers pass (seq_len, batch, d_model).
        return self.transformer_encoder(self.pos_encoder(src))

class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position signal to a tensor, then apply dropout.

    Adapted from the PyTorch sequence-to-sequence tutorial. Unlike the
    tutorial version, odd ``d_model`` values (including 1) are accepted.

    Args:
        d_model: Size of the feature (last) axis of the encoding table.
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column index: ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns. For odd d_model that is one
        # fewer than len(div_term), so the frequencies are trimmed to fit;
        # for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so the middle axis broadcasts.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``.

        The table is sliced by ``x.size(0)``, so axis 0 of ``x`` is treated as
        the position axis. ``x`` is expected to be 3-D with a last axis of
        size ``d_model``; other shapes are combined by ordinary broadcasting
        and may silently change the tensor's shape.
        """
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
# Repro 1: 2-D input with d_model=1. The input has no trailing feature axis,
# so adding the (max_len, 1, d_model) positional table sliced to 3 rows
# broadcasts the (3, 4) input up to (3, 3, 4). The encoder then presumably
# sees a last axis of size 4 instead of d_model=1, which would explain the
# failed `embed_dim == embed_dim_to_check` assertion.
# NOTE(review): PositionalEncoding slices its table by x.size(0), i.e. it
# treats axis 0 as the sequence axis -- confirm the axis labels below.
input_batch = torch.randn(3, 4)  # (batch_size, seq_len)
net = Trans(d_model=1, nhead=1, nhid=2048, nlayers=2)
out = net(input_batch)  # AssertionError

# Repro 2: 3-D input whose last axis (2) equals d_model, so the positional
# table broadcasts without changing the input's shape.
input_batch = torch.randn(3, 4, 2)  # (batch_size, seq_len, num_features)
net = Trans(d_model=2, nhead=1, nhid=2048, nlayers=2)
out = net(input_batch)  # NO ERRORS!

Error Traceback

Traceback (most recent call last):
  File "test.py", line 42, in <module>
    out = net(input_batch)
  File "/opt/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "test.py", line 16, in forward
    out = self.transformer_encoder(src)
  File "/opt/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/transformer.py", line 181, in forward
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
  File "/opt/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/transformer.py", line 293, in forward
    src2 = self.self_attn(src, src, src, attn_mask=src_mask,
  File "/opt/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 978, in forward
    return F.multi_head_attention_forward(
  File "/opt/anaconda3/envs/test/lib/python3.8/site-packages/torch/nn/functional.py", line 4135, in multi_head_attention_forward
    assert embed_dim == embed_dim_to_check
AssertionError

Even more puzzling: if we change input_batch from torch.randn(3, 4, 2)

# Working configuration: the input's last axis (2) equals d_model=2.
input_batch = torch.randn(3, 4, 2)  # (batch_size, seq_len, num_features)
net = Trans(d_model=2, nhead=1, nhid=2048, nlayers=2)
out = net(input_batch)  # NO ERRORS!

to torch.randn(3, 4, 3)

# Odd d_model: per the traceback, the RuntimeError is raised while
# constructing Trans (inside PositionalEncoding.__init__, at the cosine
# assignment), before the forward call on the last line is reached.
input_batch = torch.randn(3, 4, 3)  # (batch_size, seq_len, num_features)
net = Trans(d_model=3, nhead=1, nhid=2048, nlayers=2)
out = net(input_batch)  # RuntimeError!

we now get an error

RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 1. Target sizes: [5000, 1]. Tensor sizes: [5000, 2]

with the error traceback

Traceback (most recent call last):
  File "test.py", line 42, in <module>
    net = Trans(d_model=3, nhead=2, nhid=2048, nlayers=2)
  File "test.py", line 9, in __init__
    self.pos_encoder = PositionalEncoding(d_model, dropout)
  File "test.py", line 31, in __init__
    pe[:, 1::2] = torch.cos(position * div_term)
RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 1.  Target sizes: [5000, 1].  Tensor sizes: [5000, 2]

Switching once more from torch.randn(3, 4, 3) to torch.randn(3, 4, 4) (with d_model=4) solves the error.