Garbage output at inference stage

Greetings all,

I’m having some trouble with my code. I’m building a transformer in PyTorch, and I’ve hit problems at the inference stage. Here’s my code:

import torch
from torch import nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "embed_size must be divisible by heads"

        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding dimension into `heads` pieces
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # out shape: (N, query_len, heads * head_dim)

        out = self.fc_out(out)
        return out
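For context, this is the shape behaviour I expect out of the attention block. A minimal sanity check (toy sizes only; embed_size=128 and heads=4 just mirror the defaults I use further down):

attn = SelfAttention(embed_size=128, heads=4)
x = torch.randn(2, 6, 128)      # (N, seq_len, embed_size)
out = attn(x, x, x, mask=None)  # plain self-attention: values = keys = queries
print(out.shape)                # torch.Size([2, 6, 128])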

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.norm1(query + self.dropout(attention))
        forward = self.feed_forward(x)
        out = self.norm2(x + self.dropout(forward))
        return out

class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)  # token embeddings
        self.position_embedding = nn.Embedding(max_length, embed_size)  # learned positional embeddings

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape  # N is the batch size, seq_length the sequence length
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)  # indices 0..seq_length-1 per row

        # Dropout over the sum of token and position embeddings
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)
        return out

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = SelfAttention(embed_size, heads)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        # Masked self-attention over the target, then cross-attention to the encoder output
        attention = self.attention(x, x, x, trg_mask)
        query = self.norm(x + self.dropout(attention))
        out = self.transformer_block(value, key, query, src_mask)
        return out

class Decoder(nn.Module):
    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length,
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )

        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out

class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=128,
        num_layers=3,
        forward_expansion=2,
        heads=4,
        dropout=0.05,
        device="cpu",
        max_length=6,
    ):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

        self.sm_out = nn.Softmax(dim=2)

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # src_mask shape: (N, 1, 1, src_len)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        # trg_mask shape: (N, 1, trg_len, trg_len), lower-triangular (causal)
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        out = self.sm_out(out)
        return out
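To show how I understand the masks and the forward pass, here's a small smoke test (toy tensors only, not part of my training code):

m = Transformer(src_vocab_size=14, trg_vocab_size=14, src_pad_idx=0, trg_pad_idx=0)
src_toy = torch.randint(1, 14, (2, 6))   # (N, src_len)
trg_toy = torch.randint(1, 14, (2, 6))   # (N, trg_len)
print(m.make_src_mask(src_toy).shape)    # torch.Size([2, 1, 1, 6]) -- hides pad tokens
print(m.make_trg_mask(trg_toy).shape)    # torch.Size([2, 1, 6, 6]) -- causal mask
print(m(src_toy, trg_toy).shape)         # torch.Size([2, 6, 14]), each row sums to 1 after sm_out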

Then I have a batching function and random input generation.

def get_batches(arr_x, arr_y, batch_size):
    # Yield consecutive batch_size-row slices of both arrays
    # (any leftover rows that don't fill a batch are dropped)
    prv = 0
    for n in range(batch_size, arr_x.shape[0] + 1, batch_size):
        x = arr_x[prv:n, :]
        y = arr_y[prv:n, :]
        prv = n
        yield x, y
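As a toy check of the batcher (not in my real script): with 128 rows and batch_size=32 I get exactly four (32, 6) slices.

import numpy as np

xs = np.arange(128 * 6).reshape(128, 6)
for bx, by in get_batches(xs, xs, batch_size=32):
    print(bx.shape)  # (32, 6), printed four times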

import numpy as np
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

src2 = np.random.randint(8, size=(128, 6)) + 1   # ints in 1..8 (0 is reserved for padding)
trg2 = src2 + 5                                  # ints in 6..13, so everything fits in vocab_size=14
print(trg2)  # a (128, 6) array of random ints, each 5 more than src2

I train for about 100 epochs; the learning rate is high so the model catches on pretty quickly:

model = Transformer(src_vocab_size=14, trg_vocab_size=14, src_pad_idx=0, trg_pad_idx=0).to(device)

def train_model(src, trg, epochs=101, batch_size=32, classes=14):
    print(model)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.functional.cross_entropy

    total_loss = 0
    train_loss_list, validation_loss_list = [], []
    total_tokens = 0
    for e in range(epochs):
        model.train()
        loss = 0
        if e % 50 == 0:
            print("-" * 25, f"Epoch {e + 1}", "-" * 25)
        for x, y in get_batches(src, trg, batch_size):
            x = torch.from_numpy(x.astype("int64")).to(device)
            y = torch.from_numpy(y.astype("int64")).to(device)
            pred = model(x, y)
            if e % 50 == 0:
                print(pred.shape)
                print(torch.argmax(pred, dim=2)[0])
                print(y[0])
            loss = loss_fn(pred, nn.functional.one_hot(y, num_classes=classes).float())
            total_tokens += batch_size
            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.detach().item()
            train_loss_list += [total_loss / total_tokens]

        if e % 50 == 0:
            print(f"Training loss: {loss / batch_size:.4f}, Total loss: {total_loss:.4f}")
        print("...")

    return train_loss_list

train_loss_list = train_model(src=src2, trg=trg2)
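To at least eyeball whether the loss is going down, I plot the list the function returns (a sketch; assumes matplotlib is installed):

import matplotlib.pyplot as plt

plt.plot(train_loss_list)
plt.xlabel("batch updates")
plt.ylabel("total_loss / total_tokens")
plt.show()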

Now for my woes. When I try to infer with the model on a simple input, I get garbage output.

def greedy_decode(model, src, max_len):
    ys = torch.zeros(1, 6).type_as(src)  # start from an all-zeros (pad) target
    print(ys)
    for i in range(max_len):
        prob = model(src, ys)
        print(prob)
        next_word = torch.argmax(prob, dim=2)
        print(next_word)
        next_word = next_word.data[0]
        print(next_word.data[i])
        ys[0, i] += next_word.data[i]
    return ys

model.eval()
src = torch.LongTensor([[2, 3, 4, 5, 6, 7]]).to(device)
print(greedy_decode(model, src, max_len=6))

The output:
tensor([[0, 0, 0, 0, 0, 0]])
tensor([[[0.0288, 0.0460, 0.0148, 0.0124, 0.0357, 0.0194, 0.0374, 0.0296,
          0.1791, 0.1655, 0.1428, 0.0337, 0.0816, 0.1730],
         [0.0349, 0.0780, 0.0151, 0.0332, 0.0361, 0.0238, 0.0523, 0.0159,
          0.0896, 0.0648, 0.1253, 0.0349, 0.2191, 0.1770],
         [0.0480, 0.0754, 0.0158, 0.0214, 0.0664, 0.0754, 0.0651, 0.0177,
          0.2211, 0.0690, 0.0944, 0.0270, 0.1367, 0.0665],
         [0.0338, 0.0433, 0.0159, 0.0246, 0.0476, 0.0320, 0.0393, 0.0061,
          0.1299, 0.1012, 0.1060, 0.0221, 0.1791, 0.2190],
         [0.0320, 0.0527, 0.0133, 0.0271, 0.0659, 0.0392, 0.0235, 0.0104,
          0.0354, 0.1256, 0.0601, 0.0214, 0.1289, 0.3646],
         [0.0267, 0.0525, 0.0186, 0.0310, 0.0473, 0.0292, 0.0658, 0.0112,
          0.1424, 0.1685, 0.1065, 0.0492, 0.0946, 0.1564]]],
       grad_fn=<SoftmaxBackward0>)
tensor([[ 8, 12,  8, 13, 13,  9]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
          1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
          1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
         [3.0471e-02, 1.0963e-01, 2.6189e-02, 2.8180e-02, 4.6828e-02,
          2.7923e-02, 4.8535e-02, 1.7435e-02, 2.1706e-01, 5.1900e-02,
          1.0358e-01, 3.3658e-02, 1.9026e-01, 6.8364e-02],
         [3.7262e-02, 7.9151e-02, 1.8621e-02, 1.5883e-02, 6.6042e-02,
          7.3598e-02, 4.7058e-02, 1.4851e-02, 3.6171e-01, 5.1541e-02,
          7.0141e-02, 2.7503e-02, 1.0729e-01, 2.9345e-02],
         [3.1110e-02, 5.1011e-02, 1.9348e-02, 2.1894e-02, 5.1082e-02,
          3.4915e-02, 3.4925e-02, 5.6201e-03, 2.1383e-01, 8.9307e-02,
          9.1074e-02, 2.3128e-02, 1.9180e-01, 1.4095e-01],
         [3.1558e-02, 6.9342e-02, 1.7462e-02, 2.6760e-02, 7.8564e-02,
          4.4574e-02, 2.2642e-02, 1.0699e-02, 5.6768e-02, 1.4076e-01,
          5.7005e-02, 2.4431e-02, 1.3645e-01, 2.8298e-01],
         [2.3861e-02, 6.0398e-02, 2.2116e-02, 2.6945e-02, 4.8900e-02,
          2.9926e-02, 6.3012e-02, 9.8275e-03, 2.0789e-01, 1.6895e-01,
          9.8798e-02, 4.9619e-02, 8.3311e-02, 1.0644e-01]]],
       grad_fn=<SoftmaxBackward0>)
tensor([[ 8,  8,  8,  8, 13,  8]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
          1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
          1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
         [5.0727e-04, 7.7619e-04, 4.7299e-04, 1.5359e-04, 2.1579e-04,
          1.5301e-04, 2.3866e-04, 4.8841e-04, 9.9550e-01, 1.5122e-04,
          1.5233e-04, 4.4606e-04, 6.0918e-04, 1.3881e-04],
         [2.4091e-02, 6.6588e-02, 2.0767e-02, 1.1727e-02, 5.5241e-02,
          6.0745e-02, 3.3838e-02, 1.2357e-02, 5.2257e-01, 3.3011e-02,
          4.7262e-02, 2.3244e-02, 7.4178e-02, 1.4385e-02],
         [2.5230e-02, 5.1829e-02, 2.1725e-02, 1.9427e-02, 4.9629e-02,
          3.4386e-02, 3.0234e-02, 5.2530e-03, 3.1320e-01, 7.1912e-02,
          7.7694e-02, 2.3634e-02, 1.7828e-01, 9.7569e-02],
         [2.8949e-02, 8.1070e-02, 2.2353e-02, 2.6622e-02, 8.7250e-02,
          4.9413e-02, 2.2407e-02, 1.1211e-02, 8.7391e-02, 1.4620e-01,
          5.3241e-02, 2.5863e-02, 1.3452e-01, 2.2351e-01],
         [2.0373e-02, 6.2544e-02, 2.5682e-02, 2.4110e-02, 4.8708e-02,
          2.9638e-02, 6.1629e-02, 9.1208e-03, 2.8317e-01, 1.5011e-01,
          8.9950e-02, 4.8338e-02, 7.1531e-02, 7.5104e-02]]],
       grad_fn=<SoftmaxBackward0>)
tensor([[ 8,  8,  8,  8, 13,  8]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
          1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
          1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
         [5.0727e-04, 7.7619e-04, 4.7299e-04, 1.5359e-04, 2.1579e-04,
          1.5301e-04, 2.3866e-04, 4.8841e-04, 9.9550e-01, 1.5122e-04,
          1.5233e-04, 4.4606e-04, 6.0918e-04, 1.3881e-04],
         [6.5134e-04, 7.3119e-04, 5.0800e-04, 1.1433e-04, 3.7585e-04,
          5.0196e-04, 3.0340e-04, 5.2323e-04, 9.9490e-01, 2.1529e-04,
          1.6897e-04, 3.4757e-04, 5.7473e-04, 8.5880e-05],
         [2.0070e-02, 5.0681e-02, 2.4773e-02, 1.8060e-02, 4.9905e-02,
          3.3661e-02, 2.7250e-02, 5.5127e-03, 4.0430e-01, 5.2429e-02,
          6.6296e-02, 2.2382e-02, 1.6011e-01, 6.4569e-02],
         [2.6628e-02, 8.8863e-02, 2.7116e-02, 2.6948e-02, 9.6683e-02,
          5.3906e-02, 2.1981e-02, 1.2533e-02, 1.2193e-01, 1.3180e-01,
          5.0072e-02, 2.6949e-02, 1.3609e-01, 1.7850e-01],
         [1.7873e-02, 6.2392e-02, 2.8520e-02, 2.2571e-02, 4.9959e-02,
          2.9230e-02, 6.0232e-02, 9.2869e-03, 3.4494e-01, 1.2532e-01,
          8.2190e-02, 4.8685e-02, 6.4112e-02, 5.4695e-02]]],
       grad_fn=<SoftmaxBackward0>)
tensor([[ 8,  8,  8,  8, 13,  8]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
          1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
          1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
         [5.0727e-04, 7.7619e-04, 4.7299e-04, 1.5359e-04, 2.1579e-04,
          1.5301e-04, 2.3866e-04, 4.8841e-04, 9.9550e-01, 1.5122e-04,
          1.5233e-04, 4.4606e-04, 6.0918e-04, 1.3881e-04],
         [6.5134e-04, 7.3119e-04, 5.0800e-04, 1.1433e-04, 3.7585e-04,
          5.0196e-04, 3.0340e-04, 5.2323e-04, 9.9490e-01, 2.1529e-04,
          1.6897e-04, 3.4757e-04, 5.7473e-04, 8.5880e-05],
         [3.4224e-04, 3.4134e-04, 3.1641e-04, 9.3156e-05, 1.9721e-04,
          1.2940e-04, 1.0635e-04, 1.6255e-04, 9.9706e-01, 1.8321e-04,
          1.1403e-04, 2.0476e-04, 5.8318e-04, 1.6127e-04],
         [2.3426e-02, 9.6285e-02, 3.3548e-02, 2.5118e-02, 1.0165e-01,
          5.6494e-02, 2.1736e-02, 1.3782e-02, 1.7155e-01, 1.1895e-01,
          4.6460e-02, 2.7188e-02, 1.2749e-01, 1.3631e-01],
         [1.4750e-02, 6.0416e-02, 3.1490e-02, 1.9243e-02, 4.8390e-02,
          2.8151e-02, 5.5594e-02, 9.3160e-03, 4.1814e-01, 1.0462e-01,
          7.4928e-02, 4.4255e-02, 5.3903e-02, 3.6806e-02]]],
       grad_fn=<SoftmaxBackward0>)
tensor([[8, 8, 8, 8, 8, 8]])
tensor(8)
tensor([[[3.4346e-04, 4.4885e-04, 2.9639e-04, 5.8285e-05, 1.6682e-04,
          1.0215e-04, 1.4989e-04, 5.4101e-04, 9.9632e-01, 4.6345e-04,
          1.8907e-04, 4.5108e-04, 3.3548e-04, 1.3614e-04],
         [5.0727e-04, 7.7619e-04, 4.7299e-04, 1.5359e-04, 2.1579e-04,
          1.5301e-04, 2.3866e-04, 4.8841e-04, 9.9550e-01, 1.5122e-04,
          1.5233e-04, 4.4606e-04, 6.0918e-04, 1.3881e-04],
         [6.5134e-04, 7.3119e-04, 5.0800e-04, 1.1433e-04, 3.7585e-04,
          5.0196e-04, 3.0340e-04, 5.2323e-04, 9.9490e-01, 2.1529e-04,
          1.6897e-04, 3.4757e-04, 5.7473e-04, 8.5880e-05],
         [3.4224e-04, 3.4134e-04, 3.1641e-04, 9.3156e-05, 1.9721e-04,
          1.2940e-04, 1.0635e-04, 1.6255e-04, 9.9706e-01, 1.8321e-04,
          1.1403e-04, 2.0476e-04, 5.8318e-04, 1.6127e-04],
         [7.2871e-04, 9.4293e-04, 4.8325e-04, 1.6221e-04, 4.3626e-04,
          3.0433e-04, 1.8114e-04, 7.3524e-04, 9.9360e-01, 4.9332e-04,
          7.6330e-05, 5.0833e-04, 1.0233e-03, 3.2313e-04],
         [1.1731e-02, 5.6806e-02, 3.3593e-02, 1.6554e-02, 4.6226e-02,
          2.6792e-02, 4.8839e-02, 9.7105e-03, 4.9776e-01, 7.5886e-02,
          5.9116e-02, 4.2508e-02, 4.8548e-02, 2.5929e-02]]],
       grad_fn=<SoftmaxBackward0>)
tensor([[8, 8, 8, 8, 8, 8]])
tensor(8)
tensor([[8, 8, 8, 8, 8, 8]]) # expected [7,8,9,10,11,12]

Where am I going wrong? Am I on the right path here? Is my model learning anything?