forward() takes 2 positional arguments but 3 were given for predefined Transformer Decoder layer

I am using the predefined TransformerDecoderLayer from PyTorch and feeding the output from the encoder into the decoder layer, but I am getting ‘forward() takes 2 positional arguments but 3 were given’. I tried checking with torchsummary but I am still facing the same issue.

class Decoder(nn.Module):
  def __init__(
      self, 
      in_channels:int=1,
      patch_size:int=16,
      num_transformer_layers:int=6,
      embedding_dim:int=768,
      mlp_size:int=3072,
      num_heads:int=12,
      attn_dropout:float=0.1,
      mlp_dropout:float=0.1,
      embedding_dropout:float=0.1,
      out_features:int=512
  ):
    super().__init__()
    self.DecoderPrenet = Prenet(input_size=embedding_dim, output_size=out_features*2, hidden_size=embedding_dim)
    self.embedding = PatchEmbedding(in_channels=in_channels,
                                    patch_size=patch_size,
                                    embedding_dim=embedding_dim)
    self.pe = PositionalEncoding(d_model=768)
    self.embedding_dropout = nn.Dropout(p=embedding_dropout)
    self.mel_linear = nn.Linear(embedding_dim, hp.num_mels * hp.outputs_per_step)
    self.transformer_decoder = nn.Sequential(*[nn.TransformerDecoderLayer(d_model=embedding_dim,
                                                                            nhead=num_heads,
                                                                            dim_feedforward=mlp_size,
                                                                            dropout=mlp_dropout, activation='gelu',
                                                                            batch_first=True, norm_first=True) for _ in range(num_transformer_layers)])
    self.PostConvNet = PostConvNet(num_hidden=embedding_dim)
    self.encoder = SpeechtoText()

  def forward(self, x, src):
      batch_size = 16
      memory = self.encoder(src)
      x = self.embedding(x)
      x = self.DecoderPrenet(x)
      x = self.pe(x)
      expand = x.expand(batch_size, -1, -1)
      x = self.embedding_dropout(expand)
      x = self.transformer_decoder(x, memory)
      mel_out = self.mel_linear(x)
      y = torch.flatten(mel_out, start_dim=0, end_dim=1)
      postnet_input = y.transpose(0, 1)
      postconvnet = self.PostConvNet(postnet_input)
      out = postnet_input + postconvnet
      out = out.transpose(1, 0)
      k = out.shape[1]
      out = torch.reshape(out, (batch_size, m, k))
      return out

Please share the part of your code where you put data into the model, i.e.:


model_output=model(data1, data2, data3)

Hi, I was checking with the torchinfo summary and a random tensor. This is what I have done:

t2s = Decoder(in_channels=1,
      patch_size=16,
      num_transformer_layers=6,
      embedding_dim=768,
      mlp_size=3072,
      num_heads=12,
      attn_dropout=0.1,
      mlp_dropout=0.1,
      embedding_dropout=0.1,
      out_features=512)
from torchinfo import summary

summary(model=t2s,
        input_size=[(300, 1, 80), (409, 1, 80)],
        col_names=["input_size","output_size","num_params","trainable"],
        col_width=20,
        row_settings=["var_names"])

As you have several PyTorch modules in your model, each with its own forward pass, including your model itself, it’s not clear from what you’ve shared so far which one is the problem. Can you copy the exact error message?

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    286             if isinstance(x, (list, tuple)):
--> 287                 _ = model.to(device)(*x, **kwargs)
    288             elif isinstance(x, dict):

5 frames
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1207 
-> 1208         result = forward_call(*input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:

<ipython-input-41-67743ea64760> in forward(self, x, src)
     47       x = self.embedding_dropout(expand)
---> 48       x = self.transformer_decoder(x, memory)
     49       mel_out = self.mel_linear(x)

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1207 
-> 1208         result = forward_call(*input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:

TypeError: forward() takes 2 positional arguments but 3 were given

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-42-5f31fbbf486c> in <module>
     11 from torchinfo import summary
     12 
---> 13 summary(model=t2s,
     14         input_size=[(300, 1, 80), (409, 1, 80)],
     15         col_names=["input_size","output_size","num_params","trainable"],

/usr/local/lib/python3.8/dist-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, mode, row_settings, verbose, **kwargs)
    215         input_data, input_size, batch_dim, device, dtypes
    216     )
--> 217     summary_list = forward_pass(
    218         model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
    219     )

/usr/local/lib/python3.8/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    294     except Exception as e:
    295         executed_layers = [layer for layer in summary_list if layer.executed]
--> 296         raise RuntimeError(
    297             "Failed to run torchinfo. See above stack traces for more details. "
    298             f"Executed layers up to: {executed_layers}"

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [SpeechtoText: 1, PatchEmbedding: 2, Conv1d: 3, PositionalEncoding: 2, Dropout: 2, Sequential: 2, TransformerEncoderLayer: 3, TransformerEncoderLayer: 3, TransformerEncoderLayer: 3, TransformerEncoderLayer: 3, TransformerEncoderLayer: 3, TransformerEncoderLayer: 3, PatchEmbedding: 1, Conv1d: 2, Prenet: 1, Sequential: 2, Linear: 3, ReLU: 3, Dropout: 3, Linear: 3, ReLU: 3, Dropout: 3, PositionalEncoding: 1, Dropout: 1]

This is the error I got. I am facing the issue at the line where I give the encoder output to the decoder along with the decoder input:

---> 48 x = self.transformer_decoder(x, memory)

Seems the problem is coming from SpeechtoText(), as it is giving two outputs. Can you share your code for this module, or link to where you got it from?

I wrote this block myself, and when I tried running it through torchsummary I was getting proper output:

class SpeechtoText(nn.Module):
  def __init__(
      self, 
      in_channels:int=1,
      patch_size:int=16,
      num_transformer_layers:int=6,
      embedding_dim:int=768,
      mlp_size:int=3072,
      num_heads:int=12,
      attn_dropout:float=0.1,
      mlp_dropout:float=0.1,
      embedding_dropout:float=0.1,
      out_features:int=512
  ):
    super().__init__()
    self.embedding = PatchEmbedding(in_channels=in_channels,
                                    patch_size=patch_size,
                                    embedding_dim=embedding_dim)
    self.pe = PositionalEncoding(d_model=768)
    self.embedding_dropout = nn.Dropout(p=embedding_dropout)

    self.transformer_encoder = nn.Sequential(*[nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                                            nhead=num_heads,
                                                                            dim_feedforward=mlp_size,
                                                                            dropout=mlp_dropout, activation='gelu',
                                                                            batch_first=True, norm_first=True) for _ in range(num_transformer_layers)])
    """
    self.transformer_encoder = nn.Sequential(*[Encoder(embedding_dim=embedding_dim,
                                                                            num_heads=num_heads,
                                                                            mlp_size=mlp_size,
                                                                            mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])
    """
    self.output = nn.Sequential(
        nn.LayerNorm(normalized_shape=embedding_dim),
        nn.Linear(in_features=embedding_dim,
                  out_features=out_features)
    )
  def forward(self, x):
    batch_size = 16
    x = self.embedding(x)
    x = self.pe(x)
    expand = x.expand(batch_size, -1, -1)
    x = self.embedding_dropout(expand)
    x = self.transformer_encoder(x)
    return x

That seems to check out okay.

Try this and see what it prints:

print(x.size(), memory.size()) # <<<< new line
x = self.transformer_decoder(x, memory)
      
self.transformer_decoder = nn.Sequential(*[nn.TransformerDecoderLayer(d_model=embedding_dim,
                                                                            nhead=num_heads,
                                                                            dim_feedforward=mlp_size,
                                                                            dropout=mlp_dropout, activation='gelu',
                                                                            batch_first=True, norm_first=True) for _ in range(num_transformer_layers)])

Try changing this to:


self.transformer_decoder = nn.Sequential(*[nn.TransformerDecoderLayer(d_model=embedding_dim, num_layers=num_transformer_layers,
                                                                            nhead=num_heads,
                                                                            dim_feedforward=mlp_size,
                                                                            dropout=mlp_dropout, activation='gelu',
                                                                            batch_first=True, norm_first=True)])

Indented correctly, of course.

Printing the shapes before that line, I get torch.Size([16, 300, 768]) and torch.Size([16, 409, 768]), which I think is correct. And I don't think num_layers=num_transformer_layers will work, as this module is predefined by PyTorch and it doesn't have a num_layers argument.

This reproduces your issue:

import torch
import torch.nn as nn

num_transformer_layers=2
transformer_decoder = nn.Sequential(*[nn.TransformerDecoderLayer(d_model=512, nhead=8,activation='gelu',
                                                                            batch_first=True) for _ in range(num_transformer_layers)])
src = torch.rand((32, 10, 512))
tgt = torch.rand((32, 20, 512))

out = transformer_decoder(src, tgt)
print(out.size())
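
Running this raises the same TypeError: forward() takes 2 positional arguments but 3 were given.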
    

I have no idea who wrote that code for your Decoder class. Where did you find it? And what version of PyTorch are you using?

I am using the latest PyTorch, and the class that I am using can be found here: TransformerDecoderLayer — PyTorch 1.13 documentation

I am referring to the entire “Decoder” class above in asking where you got this code. I’m guessing you got it off a GitHub repo and are just having a PyTorch version mismatch. Maybe they used an earlier version.

No, I wrote the code myself.

This particular instantiation of the TransformerDecoderLayer doesn’t work. Each TransformerDecoderLayer takes two inputs (the target and the encoder memory) and gives one output, but nn.Sequential’s forward() only accepts a single input, so there is no way to hand memory to each layer inside the Sequential, and the call self.transformer_decoder(x, memory) fails with exactly this error.

You can test this is the issue by changing that line to:

self.transformer_decoder = nn.TransformerDecoderLayer(d_model=embedding_dim,
                                                                            nhead=num_heads,
                                                                            dim_feedforward=mlp_size,
                                                                            dropout=mlp_dropout, activation='gelu',
                                                                            batch_first=True, norm_first=True)
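
With a single layer like that, the self.transformer_decoder(x, memory) call should run and give you back a [16, 300, 768] tensor.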
    

Ok, understood the problem, but is there a workaround, or do I have to use just one decoder layer?

Please check the documentation for the TransformerDecoderLayer. It has an argument num_layers, so that you can use the native method to build more layers, which is what I was suggesting earlier.

Sorry, correction. There is a separate class that does not append the word “Layer” that you can use:

https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
memory = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
out = transformer_decoder(tgt, memory)
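
Applied to your Decoder class, the fix could look something like this (just a sketch, keeping batch_first=True and norm_first=True so it matches the [batch, seq, 768] shapes you printed earlier):

decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim,
                                           nhead=num_heads,
                                           dim_feedforward=mlp_size,
                                           dropout=mlp_dropout,
                                           activation='gelu',
                                           batch_first=True,
                                           norm_first=True)
# nn.TransformerDecoder clones this layer num_layers times and, unlike nn.Sequential,
# its forward(tgt, memory) passes both tensors to every layer
self.transformer_decoder = nn.TransformerDecoder(decoder_layer,
                                                 num_layers=num_transformer_layers)

With that, the existing x = self.transformer_decoder(x, memory) call in your forward() can stay as it is. (The alternative would be to keep the layers in an nn.ModuleList and loop over them yourself, passing memory to each one, but nn.TransformerDecoder does that for you.)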

Yeah, done. Thank you so much for your help!
