I’m trying to train a Transformer model in which the source and target sequences have feature vectors of different sizes.
Since the default nn.Transformer relies on source and target feature vectors of the same size (d_model), I use a custom encoder.
Here’s a minimal reproducible example:
import torch
import torch.nn as nn

# source: (seq_len=10, batch=32, features=256); target: (seq_len=20, batch=32, features=512)
src = torch.rand((10, 32, 256))
tgt = torch.rand((20, 32, 512))

# custom encoder operating on the 256-dim source features
encoder_layer = nn.TransformerEncoderLayer(256, 8, 512, 0.5)
encoder_norm = nn.LayerNorm(256)
encoder = nn.TransformerEncoder(encoder_layer, 8, encoder_norm)

# the decoder side keeps the default d_model=512
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, custom_encoder=encoder)
out = transformer_model(src, tgt)
I get the following error message:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-2-8b0f6272af69> in <module>
6 transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, custom_encoder = encoder)
7
----> 8 out = transformer_model(src, tgt)
9 out
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/transformer.py in forward(self, src, tgt, src_mask, tgt_mask, memory_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask)
140
141 memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
--> 142 output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
143 tgt_key_padding_mask=tgt_key_padding_mask,
144 memory_key_padding_mask=memory_key_padding_mask)
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/transformer.py in forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
246
247 for mod in self.layers:
--> 248 output = mod(output, memory, tgt_mask=tgt_mask,
249 memory_mask=memory_mask,
250 tgt_key_padding_mask=tgt_key_padding_mask,
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/transformer.py in forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)
450 else:
451 x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
--> 452 x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
453 x = self.norm3(x + self._ff_block(x))
454
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/transformer.py in _mha_block(self, x, mem, attn_mask, key_padding_mask)
467 def _mha_block(self, x: Tensor, mem: Tensor,
468 attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
--> 469 x = self.multihead_attn(x, mem, mem,
470 attn_mask=attn_mask,
471 key_padding_mask=key_padding_mask,
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/activation.py in forward(self, query, key, value, key_padding_mask, need_weights, attn_mask)
1001 v_proj_weight=self.v_proj_weight)
1002 else:
-> 1003 attn_output, attn_output_weights = F.multi_head_attention_forward(
1004 query, key, value, self.embed_dim, self.num_heads,
1005 self.in_proj_weight, self.in_proj_bias,
~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
4986 #
4987 if not use_separate_proj_weight:
-> 4988 q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
4989 else:
4990 assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in _in_projection_packed(q, k, v, w, b)
4744 else:
4745 b_q, b_kv = b.split([E, E * 2])
-> 4746 return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).chunk(2, dim=-1)
4747 else:
4748 w_q, w_k, w_v = w.chunk(3)
~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1846 if has_torch_function_variadic(input, weight, bias):
1847 return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848 return torch._C._nn.linear(input, weight, bias)
1849
1850
RuntimeError: mat1 and mat2 shapes cannot be multiplied (320x256 and 512x1024)
Do you have any idea what I might be doing wrong?
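From the shapes in the error, I suspect the mismatch happens in the decoder: 320x256 looks like my 256-dimensional encoder memory flattened over 10*32 positions, while 512x1024 looks like the decoder's key/value in-projection for the default d_model=512. So the decoder apparently still expects 512-dimensional memory even though my custom encoder outputs 256-dimensional features.

Would the right fix simply be to project the source up to 512 features before the Transformer and drop the custom encoder? Here is a rough sketch of what I have in mind (the src_proj layer is just my own idea, not something from the docs):

import torch
import torch.nn as nn

src = torch.rand((10, 32, 256))   # 256-dim source features
tgt = torch.rand((20, 32, 512))   # 512-dim target features

# hypothetical input projection: lift the source features to the model dimension
src_proj = nn.Linear(256, 512)
transformer_model = nn.Transformer(d_model=512, nhead=16, num_encoder_layers=12)

out = transformer_model(src_proj(src), tgt)
print(out.shape)  # torch.Size([20, 32, 512])

Or is there a way to keep the 256-dimensional custom encoder, e.g. by projecting its output (the memory) to 512 before it reaches the decoder?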