Using different feature sizes for source and target with nn.Transformer

I’m trying to train a Transformer model with source and target sequences having feature vectors of different sizes.
As the default Transformer relies on source and target feature vectors of the same size, I use a custom encoder.

Here’s a minimal reproducible example:

import torch
import torch.nn as nn

# Source input: the last dimension (feature size) is 256.
# NOTE(review): presumably laid out as (seq_len, batch, features) — confirm.
src = torch.rand((10, 32, 256))
# Target input: the last dimension (feature size) is 512, i.e. it differs from src.
tgt = torch.rand((20, 32, 512))

# Custom encoder sized for the 256-feature source
# (positional args presumably d_model=256, nhead=8, dim_feedforward=512, dropout=0.5 — confirm).
encoder_layer = nn.TransformerEncoderLayer(256, 8, 512, 0.5)
encoder_norm = nn.LayerNorm(256)
# Stack of 8 encoder layers followed by the final LayerNorm.
encoder = nn.TransformerEncoder(encoder_layer, 8, encoder_norm)

# d_model is not passed, so the decoder is built with the library default
# (presumably 512, judging by the 512x1024 weight in the reported error — confirm).
# NOTE(review): num_encoder_layers=12 is likely ignored because custom_encoder is supplied — confirm.
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, custom_encoder = encoder)

# This is the call that raises the RuntimeError: the encoder output has 256 features,
# while the decoder's attention over the encoder memory expects 512.
out = transformer_model(src, tgt)

I get the following error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-8b0f6272af69> in <module>
      6 transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, custom_encoder = encoder)
      7 
----> 8 out = transformer_model(src, tgt)
      9 out

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/transformer.py in forward(self, src, tgt, src_mask, tgt_mask, memory_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask)
    140 
    141         memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
--> 142         output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
    143                               tgt_key_padding_mask=tgt_key_padding_mask,
    144                               memory_key_padding_mask=memory_key_padding_mask)

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/transformer.py in forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
    246 
    247         for mod in self.layers:
--> 248             output = mod(output, memory, tgt_mask=tgt_mask,
    249                          memory_mask=memory_mask,
    250                          tgt_key_padding_mask=tgt_key_padding_mask,

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/transformer.py in forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
    450         else:
    451             x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
--> 452             x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
    453             x = self.norm3(x + self._ff_block(x))
    454 

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/transformer.py in _mha_block(self, x, mem, attn_mask, key_padding_mask)
    467     def _mha_block(self, x: Tensor, mem: Tensor,
    468                    attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
--> 469         x = self.multihead_attn(x, mem, mem,
    470                                 attn_mask=attn_mask,
    471                                 key_padding_mask=key_padding_mask,

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/activation.py in forward(self, query, key, value, key_padding_mask, need_weights, attn_mask)
   1001                 v_proj_weight=self.v_proj_weight)
   1002         else:
-> 1003             attn_output, attn_output_weights = F.multi_head_attention_forward(
   1004                 query, key, value, self.embed_dim, self.num_heads,
   1005                 self.in_proj_weight, self.in_proj_bias,

~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
   4986     #
   4987     if not use_separate_proj_weight:
-> 4988         q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
   4989     else:
   4990         assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"

~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in _in_projection_packed(q, k, v, w, b)
   4744             else:
   4745                 b_q, b_kv = b.split([E, E * 2])
-> 4746             return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
   4747     else:
   4748         w_q, w_k, w_v = w.chunk(3)

~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1846     if has_torch_function_variadic(input, weight, bias):
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 
   1850 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (320x256 and 512x1024)

Do you have any idea what I might be doing wrong?

@Codophile1 I don't think the Transformer model is designed for the encoder and decoder to have different feature sizes.

According to the research paper:

In “encoder-decoder attention” layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This allows every position in the decoder to attend over all positions in the input sequence. This mimics the typical encoder-decoder attention mechanisms in sequence-to-sequence models such as [38, 2, 9]

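We can see that the attention memory keys and values for the decoder come from the encoder. This is exactly where your error is raised: the encoder output (10 × 32 = 320 rows of 256 features) is passed through the decoder's key/value projection, which expects 512 input features, hence "mat1 and mat2 shapes cannot be multiplied (320x256 and 512x1024)". Because of this, the encoder output and the decoder need to share the same feature size, so that the decoder's attention can use the encoder's output as its keys and values.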