Assertion error when using TransformerDecoder

I am trying to build an autoencoder-style model using a Transformer in PyTorch. The encoder part of the model predicts certain values; these values are passed to the decoder, which predicts back the input. The model is updated on the loss of both outputs. I am using a TransformerEncoder in the encoder part, then passing its output to a linear layer which gives an output of size [batch, seqsize, 3]; this acts as the input for the decoder. The output of the TransformerEncoder is of size [batch, seqsize, 300]. I pass the linear layer's output as the target and the output of the TransformerEncoder as memory, but I get an error on the forward pass of the decoder:

---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

<ipython-input-25-958b5ce5b42a> in <module>()
----> 1 o1 = decoder(o,mem)

8 frames

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

<ipython-input-17-c7ce79c8fbe4> in forward(self, x, mem)
      8 
      9   def forward(self, x, mem):
---> 10     out = self.transformer_dec(x, mem)
     11     out = self.linear1(out)
     12     out = torch.softmax(out)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/transformer.py in forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
    232                          memory_mask=memory_mask,
    233                          tgt_key_padding_mask=tgt_key_padding_mask,
--> 234                          memory_key_padding_mask=memory_key_padding_mask)
    235 
    236         if self.norm is not None:

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/transformer.py in forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
    362         """
    363         tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
--> 364                               key_padding_mask=tgt_key_padding_mask)[0]
    365         tgt = tgt + self.dropout1(tgt2)
    366         tgt = self.norm1(tgt)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/activation.py in forward(self, query, key, value, key_padding_mask, need_weights, attn_mask)
    925                 training=self.training,
    926                 key_padding_mask=key_padding_mask, need_weights=need_weights,
--> 927                 attn_mask=attn_mask)
    928 
    929 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
   3947                 v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
   3948     tgt_len, bsz, embed_dim = query.size()
-> 3949     assert embed_dim == embed_dim_to_check
   3950     # allow MHA to have different sizes for the feature dimension
   3951     assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

AssertionError: 

Why am I getting this? What am I doing wrong? I couldn’t find any examples of TransformerDecoder to check whether my approach is correct.
This is the Encoder part:

class ColorTransformerEncoder(nn.Module):
  """Embed token ids, encode them with a TransformerEncoder, and map every
  position to 3 values in (0, 1).

  Args:
    input_size: embedding size; also the d_model of the transformer layers.
    num_head: number of attention heads per encoder layer.
    hidden_size: feed-forward width inside each encoder layer.
    num_layers: number of stacked encoder layers.
    vocab_size: number of embedding rows. When None (default) the
      module-level ``word_count`` is used, as in the original code.
    batch_first: when True (default) ``x`` is (batch, seq) and both outputs
      are (batch, seq, ...). The nn.TransformerEncoder used here consumes
      (seq, batch, feature) tensors, so the embeddings are transposed around
      it; without this, attention would run across the batch axis.

  NOTE(review): no positional encoding is added to the embeddings, so the
  encoder cannot distinguish token order -- confirm this is intended.
  """

  def __init__(self, input_size, num_head, hidden_size, num_layers,
               vocab_size=None, batch_first=True):
    super(ColorTransformerEncoder, self).__init__()
    if vocab_size is None:
      # Backward compatibility: the original relied on a global vocabulary size.
      vocab_size = word_count
    self.batch_first = batch_first
    self.embd = nn.Embedding(vocab_size, input_size)
    encoder_layer = nn.TransformerEncoderLayer(input_size, num_head, hidden_size)
    self.transformer_enc = nn.TransformerEncoder(encoder_layer, num_layers)
    self.linear1 = nn.Linear(input_size, 3)

  def forward(self, x):
    """Encode a batch of token ids.

    Args:
      x: tensor of token ids, (batch, seq) when ``batch_first`` else
        (seq, batch); cast to long before the embedding lookup.

    Returns:
      (out, mem): ``mem`` is the encoder output with last dimension
      ``input_size``; ``out`` is ``sigmoid(linear1(mem))`` with last
      dimension 3. Both use the same axis order as ``x``.
    """
    emb = self.embd(x.long())
    if self.batch_first:
      emb = emb.transpose(0, 1)  # (batch, seq, E) -> (seq, batch, E)
    mem = self.transformer_enc(emb)
    if self.batch_first:
      mem = mem.transpose(0, 1)  # (seq, batch, E) -> (batch, seq, E)
    out = torch.sigmoid(self.linear1(mem))
    return out, mem

This is the decoder part:

class ColorTransformerDecoder(nn.Module):
  """Decode a target sequence against encoder memory into per-position
  probability distributions over ``output_size`` classes.

  Args:
    input_size: d_model of the decoder layers; ``mem`` must have this as its
      last dimension.
    num_head: number of attention heads per decoder layer.
    output_size: number of output classes.
    hidden_size: feed-forward width inside each decoder layer.
    num_layers: number of stacked decoder layers.
    tgt_size: last dimension of the target ``x``. The decoder's
      self-attention asserts that the target's feature size equals d_model
      (``assert embed_dim == embed_dim_to_check``), so when ``tgt_size``
      differs from ``input_size`` a linear projection tgt_size -> input_size
      is applied first (e.g. ``tgt_size=3`` for the encoder's 3-value head).
      None (default) means ``x`` already has ``input_size`` features.
    batch_first: when True (default) ``x``, ``mem`` and the result are
      (batch, seq, ...). nn.TransformerDecoder consumes (seq, batch, feature)
      tensors, so inputs are transposed around it.
  """

  def __init__(self, input_size, num_head, output_size, hidden_size,
               num_layers, tgt_size=None, batch_first=True):
    super(ColorTransformerDecoder, self).__init__()
    self.batch_first = batch_first
    decoder_layer = nn.TransformerDecoderLayer(input_size, num_head, hidden_size)
    self.transformer_dec = nn.TransformerDecoder(decoder_layer, num_layers)
    self.linear1 = nn.Linear(input_size, output_size)
    # Registered last so existing parameters keep their order and init.
    if tgt_size is None or tgt_size == input_size:
      self.tgt_proj = nn.Identity()
    else:
      self.tgt_proj = nn.Linear(tgt_size, input_size)

  def forward(self, x, mem):
    """Decode ``x`` against ``mem``.

    Args:
      x: target tensor whose last dimension is ``tgt_size`` (or
        ``input_size`` when no projection was configured).
      mem: encoder output whose last dimension is ``input_size``; its
        sequence length may differ from that of ``x``.

    Returns:
      Softmax probabilities over the last (class) dimension, with the same
      leading axes as ``x``.

    NOTE(review): if the training loss is nn.CrossEntropyLoss, it applies
    log-softmax itself and expects raw logits -- confirm which is intended.
    """
    tgt = self.tgt_proj(x)
    if self.batch_first:
      tgt = tgt.transpose(0, 1)  # (batch, T, E) -> (T, batch, E)
      mem = mem.transpose(0, 1)  # (batch, S, E) -> (S, batch, E)
    out = self.transformer_dec(tgt, mem)
    if self.batch_first:
      out = out.transpose(0, 1)  # back to (batch, T, E)
    out = self.linear1(out)
    # torch.softmax requires an explicit dim; normalize over the classes.
    return torch.softmax(out, dim=-1)

This is the model initialization:

# Hyperparameters shared by the encoder and the decoder.
num_head = 2       # attention heads per transformer layer
num_layers = 2     # stacked transformer layers in each half
hidden_size = 200  # feed-forward width inside each transformer layer
emb_size = 300     # embedding size, also the d_model of both transformers

encoder = ColorTransformerEncoder(
  input_size=emb_size,
  num_head=num_head,
  hidden_size=hidden_size,
  num_layers=num_layers,
).to(device)
decoder = ColorTransformerDecoder(
  input_size=emb_size,
  num_head=num_head,
  output_size=word_count,
  hidden_size=hidden_size,
  num_layers=num_layers,
).to(device)

This is the testing code that raises the error:

# Run one sequence of six token ids (all zeros) through the encoder.
o, mem = encoder(torch.zeros(1, 6).to(device))
# NOTE: `o` has last dimension 3 (encoder.linear1 maps emb_size -> 3), while the
# decoder was built with d_model = emb_size (300). The decoder's attention
# asserts the target's feature size equals d_model, so this call raises the
# AssertionError shown in the traceback.
o1 = decoder(o, mem)

The error occurs in the second statement of the above code

Please share your full code here (both model and training part).

I added my code. Please check it.

Two things:

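  1. Make sure that the feature (last) dimensions of o and mem both match the decoder's d_model. For example, if you change this line:

     o1 = decoder(o,mem)

to o1 = decoder(mem, mem), the error goes away. This may not be what you want, but it shows that tgt and memory must both have d_model (here 300) as their last dimension; their sequence lengths may differ. Since o has last dimension 3, it would need to be projected to 300 (e.g. with an nn.Linear) before being used as tgt.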
Also, we usually use an nn.Linear layer to fine-tune the model for our specific task. In that case, you may not need a separate decoder.

  2. torch.softmax needs an explicit dimension, so you’ll need to change out = torch.softmax(out) to something like out = torch.softmax(out, dim=-1).