Transformer Error: shape '[-1, 32, 64]' is invalid for input of size 12800

Here is my Transformer

import math
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from torch.nn.modules.normalization import LayerNorm

class Transformer(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, d_model: int = 512, num_head: int = 8, num_e_layer: int = 6,
                 num_d_layer: int = 6, ff_dim: int = 2048, drop_out: float = 0.1):
            input_dim: Size of the vocab of the input
            output_dim: Size of the vocab for output
            num_head: Number of heads in mutliheaded attention models
            num_e_layer: Number of sub-encoder layers
            num_d_layer: Number of sub-decoder layers
            ff_dim: Dimension of feedforward network in mulihead models
            d_model: The dimension to embed input and output features into
            drop_out: The drop out percentage
        super(Transformer, self).__init__()
        self.decoder = TDecoder(output_dim, d_model, num_head, ff_dim, num_d_layer, drop_out)
        self.encoder = TEncoder(input_dim, d_model, num_head, ff_dim, num_e_layer, drop_out)

    def forward(self, src: torch.Tensor, trg: torch.Tensor)
        enc_out = self.encoder(src)
        return self.decoder(trg, enc_out)

class TEncoder(nn.Module):
    def __init__(self, input_dim: int, d_model: int, num_head: int, ff_dim: int, num_layers: int, drop_out: float):
        super(TEncoder, self).__init__()
        self.pos_encoder = PositionalEncoding(d_model, drop_out)
        self.embed = nn.Embedding(input_dim, d_model)
        # Encoder layer is the multiheaded attention part of the transformer
        # Norm is the normalization after every multiheaded portion
        # ff_dim is the dimension for feedforward network at the end
        layer = TransformerEncoderLayer(d_model, num_head, ff_dim, drop_out)
        norm = LayerNorm(d_model)
        self.encoder = TransformerEncoder(layer, num_layers, norm)
    def forward(self, src: torch.Tensor):
        src_embed = self.pos_encoder(self.embed(src))
        return self.encoder(src_embed)

class TDecoder(nn.Module)
    def __init__(self, input_dim: int, d_model: int, num_head: int, ff_dim: int, num_layers: int, drop_out: float):
        self.pos_encoder = PositionalEncoding(d_model, drop_out)
        self.embed = nn.Embedding(input_dim, d_model)
        # Same as Encoder
        layer = TransformerDecoderLayer(d_model, num_head, ff_dim, drop_out)
        norm = LayerNorm(d_model)
        self.decoder = TransformerDecoder(layer, num_layers, norm)
        # Ends with a linear layer and a softmax
        self.linear = nn.Linear(d_model, input_dim)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, trg: torch.Tensor, encoder_output: torch.Tensor):
        dec_mask = self._generate_square_subsequent_mask(len(trg))
        trg_embed = self.pos_encoder(self.embed(trg))
        dec_out = self.decoder(trg_embed, encoder_output, tgt_mask=dec_mask)
        output = self.linear(dec_out)
        return self.softmax(output)

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, drop_out=0.1, max_len=200):
        super(PositionalEncoding, self).__init__()
        self.drop_out = nn.Dropout(p=drop_out)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x +[:x.size(0), :]
        return self.drop_out(x)

Just to make sure I implemented it correctly I’m testing to see if I can put just a made up seq of length 5 and put that through the encoder, then pretend it generated 2 characters already, with which I already took the best probs from both and appending them then put them through the decoder.

input = torch.ones(5).type(torch.LongTensor)
trg = torch.zeros(4).type(torch.LongTensor)
transformer = Transformer(2, 1)
output = transformer.forward(input, trg)

In this example the encoder input dim is 2 and decoder input dim is 1. For some reason I get the above error.