For batched (3-D) `query`, expected `key` and `value` to be 3-D but found 2-D and 2-D tensors respectively

I’m trying to implement a paper that is heavily inspired by transformer architectures. For this I have implemented encoder and decoder architectures. However, when I pass my inputs and labels through the model, I obtain the following traceback:

  File ~\Desktop\mtl\src\architectures\unit.py:255 in <module>
    y_pred = model(inputs, label)

  File ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py:1130 in _call_impl
    return forward_call(*input, **kwargs)

  File ~\Desktop\mtl\src\architectures\unit.py:192 in forward
    dec1 = self.decoder1(trg[0], out)

  File ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py:1130 in _call_impl
    return forward_call(*input, **kwargs)

  File ~\Desktop\mtl\src\architectures\unit.py:160 in forward
    out1 = self.dec(out1, enc)

  File ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py:1130 in _call_impl
    return forward_call(*input, **kwargs)

  File ~\Desktop\mtl\src\architectures\unit.py:120 in forward
    out3, _ = self.mha2(out2, enc, enc)

  File ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py:1130 in _call_impl
    return forward_call(*input, **kwargs)

  File ~\Anaconda3\lib\site-packages\torch\nn\modules\activation.py:1153 in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(

  File ~\Anaconda3\lib\site-packages\torch\nn\functional.py:5030 in multi_head_attention_forward
    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)

  File ~\Anaconda3\lib\site-packages\torch\nn\functional.py:4874 in _mha_shape_check
    assert key.dim() == 3 and value.dim() == 3, \

AssertionError: For batched (3-D) `query`, expected `key` and `value` to be 3-D but found 2-D and 2-D tensors respectively

Suppose I have the following training script with dummy data:

    import sys
    from torch.utils.data import Dataset, DataLoader
    import torch.optim as optim
    
    sys.path.append(r'C:\Users\dalbertw\Desktop\mtl\src\data')
    from dataset import MyDataset
    
    
    
    # Dummy dataset: two image-like inputs and three label tensors.
    x1 = torch.randn((1000, 3, 28, 28))
    x2 = torch.randn((1000, 3, 28, 28))

    y1 = torch.randint(0, 3, (1000, 1))
    y2 = torch.randint(0, 3, (1000, 1))
    # NOTE(review): the upper bound of torch.randint is exclusive, so this
    # tensor only ever contains 0 -- confirm whether randint(0, 2, ...) or
    # randint(0, 3, ...) was intended.
    y3 = torch.randint(0, 1, (1000, 1))

    train_set = MyDataset(x=[x1, x2],
                          y=[y1, y2, y3])

    train_loader = DataLoader(train_set, batch_size=18, num_workers=0, shuffle=True)

    # Model and hyperparameters.
    model = UniT()
    model.to('cpu')
    # Keep the optimizer under its own name so the imported ``optim`` module is
    # not shadowed, and reach SGD through that module instead of a bare name.
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    criterion = nn.CrossEntropyLoss()

    model.train()

    for epoch in range(50):
        for idx, data in enumerate(train_loader):
            # Report both counters; ``idx`` is the batch index, not the epoch.
            print('Epoch: {} Batch: {}'.format(epoch, idx))

            # presumably each batch is [x1, x2, y1, y2, y3] -- verify against MyDataset
            inputs = data[:2]
            label = data[2:]

            optimizer.zero_grad()

            y_pred = model(inputs, label)

            print(y_pred[0].shape)

            # Sum of the three per-task classification losses.
            # NOTE(review): nn.CrossEntropyLoss expects class-index targets of
            # shape [batch_size, *additional_dims]; the labels here are created
            # with shape (N, 1) -- confirm they match the logits' shape.
            total_loss = (criterion(y_pred[0], label[0])
                          + criterion(y_pred[1], label[1])
                          + criterion(y_pred[2], label[2]))

            total_loss.backward()

            optimizer.step()

Furthermore, these are my current architectures for the encoder and decoder:

class ResidualBlock(nn.Module):
    """Two stacked 3x3 conv + batch-norm stages with ReLU activations.

    Despite the name, no skip connection is added: the input only flows
    through ``conv1`` -> ``conv2`` -> ``relu``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
    """

    def __init__(self, in_channels, out_channels, stride = 1):
        super().__init__()
        # First stage: may change channel count and spatial resolution.
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv1 = nn.Sequential(*first_stage)
        # Second stage: keeps shape; its activation is applied in forward().
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv2 = nn.Sequential(*second_stage)
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        """Return ``relu(conv2(conv1(x)))``."""
        return self.relu(self.conv2(self.conv1(x)))
    


class TaskSpecificEmbedding(nn.Module):
    """Thin wrapper around a learned embedding table.

    Args:
        num_embed: Number of rows (distinct indices) in the table.
        embed_dim: Size of each embedding vector.
    """

    def __init__(self, num_embed=1, embed_dim=28*28):
        super(TaskSpecificEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=num_embed,
                                      embedding_dim=embed_dim)

    def forward(self, x):
        """Look up the embedding vectors for the integer indices in ``x``."""
        return self.embedding(x)
    
    


class EncoderLayer(nn.Module):
    """Transformer-style encoder layer: self-attention, then a feed-forward net.

    Both sub-layers are wrapped as ``LayerNorm(residual + Dropout(sublayer))``.

    Args:
        input_size: Size of the flattened feature vector; used as the embedding
            dimension of the attention layer, the feed-forward input/output
            width and the normalized shape of both layer norms.
        ff_hidden: Width of the hidden layer of the feed-forward network.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, input_size, ff_hidden=2048, num_heads=8, dropout=0.1):
        super().__init__()
        self.mha1 = torch.nn.MultiheadAttention(input_size, num_heads)
        self.norm1 = torch.nn.LayerNorm([input_size], eps=1e-6)
        self.drop1 = torch.nn.Dropout(dropout)
        self.feed_forward = nn.Sequential(
                          nn.Linear(input_size, ff_hidden),
                          nn.ReLU(),
                          nn.Linear(ff_hidden, input_size)
        )
        self.norm2 = torch.nn.LayerNorm(input_size, eps=1e-6)
        self.drop2 = torch.nn.Dropout(dropout)

    def forward(self, x):
        """Encode ``x`` and return a tensor of shape ``(x.shape[0], input_size)``.

        Args:
            x: Tensor whose dimensions after the first are flattened into a
                single axis that must have ``input_size`` elements.
        """
        # Collapse everything after the first axis into one feature axis.
        # NOTE(review): the result is 2-D, which nn.MultiheadAttention treats
        # as one unbatched sequence, so attention mixes information along the
        # first axis -- confirm this is intended.
        x = torch.flatten(x, start_dim=1)
        # 1. compute self attention
        attn, _ = self.mha1(x, x, x)
        # 2. add and norm
        out1 = self.norm1(x + self.drop1(attn))
        # 3. position-wise feed forward network
        ff = self.drop2(self.feed_forward(out1))
        # 4. add and norm: normalize the sum, matching step 2 (previously only
        #    the residual branch was normalized before the addition).
        return self.norm2(out1 + ff)
    
    
   
class DecoderLayer(nn.Module):
    """Transformer-style decoder layer.

    Applies self-attention on the query, cross-attention against the encoder
    output and a feed-forward network. Each sub-layer is wrapped as
    ``LayerNorm(residual + Dropout(sublayer))`` using its own norm/dropout pair.

    Args:
        input_size: Embedding dimension of the query; also used as the
            embedding dimension of both attention layers.
        ff_hidden: Width of the hidden layer of the feed-forward network.
        head_size: Unused; kept so existing callers keep working.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, input_size, ff_hidden=2048, head_size=256, num_heads=8, dropout=0.1):
        super().__init__()
        self.mha1 = torch.nn.MultiheadAttention(input_size, num_heads)
        self.norm1 = torch.nn.LayerNorm(input_size, eps=1e-6)
        self.drop1 = torch.nn.Dropout(dropout)

        self.mha2 = torch.nn.MultiheadAttention(input_size, num_heads)
        self.norm2 = torch.nn.LayerNorm(input_size, eps=1e-6)
        self.drop2 = torch.nn.Dropout(dropout)

        self.feed_forward = nn.Sequential(
                          nn.Linear(input_size, ff_hidden),
                          nn.ReLU(),
                          nn.Linear(ff_hidden, input_size)
        )
        self.norm3 = torch.nn.LayerNorm(input_size, eps=1e-6)
        self.drop3 = torch.nn.Dropout(dropout)

    def forward(self, x, enc):
        """Decode the query ``x`` against the encoder output ``enc``.

        Args:
            x: Query tensor whose last dimension is ``input_size``.
            enc: Encoder output used as key and value of the cross-attention.
                nn.MultiheadAttention requires it to have the same number of
                dimensions as ``x`` and, as configured here, the same last
                dimension.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # 1. self attention on the query, then add and norm
        self_attn, _ = self.mha1(x, x, x)
        out1 = self.norm1(x + self.drop1(self_attn))

        # 2. cross attention (query from decoder, key/value from encoder),
        #    then add and norm
        cross_attn, _ = self.mha2(out1, enc, enc)
        out2 = self.norm2(out1 + self.drop2(cross_attn))

        # 3. position-wise feed forward network, then add and norm
        ff = self.drop3(self.feed_forward(out2))
        return self.norm3(out2 + ff)
        
    
class Encoder(nn.Module):
    """Convolutional backbone followed by repeated passes through one encoder layer.

    Args:
        input_size: Feature size handed to the ``EncoderLayer``.
        n_layers: How many times the (single, weight-shared) encoder layer is
            applied in ``forward``.
    """

    def __init__(self, input_size, n_layers=1):
        super(Encoder, self).__init__()
        self.n_layers = n_layers

        # Backbone maps 3 input channels to 4 feature channels.
        self.backbone1 = ResidualBlock(3, 4)
        self.transformer1 = EncoderLayer(input_size=input_size)

    def forward(self, x):
        """Run ``x`` through the backbone and then ``n_layers`` encoder passes."""
        features = self.backbone1(x)

        # The same ``transformer1`` instance is reused on every pass.
        for _ in range(self.n_layers):
            features = self.transformer1(features)
        return features
    
class Decoder(nn.Module):
    """Task-specific decoder: embeds the target indices and decodes them against the encoder output.

    Args:
        input_size: Feature size handed to the ``DecoderLayer``.
        n_layers: How many times the (single, weight-shared) decoder layer is
            applied in ``forward``.
        num_embed: Number of entries in the task-specific embedding table.
        embed_dim: Size of each embedding vector.
    """

    def __init__(self, input_size, n_layers=1, num_embed=1, embed_dim=28*28):
        super(Decoder, self).__init__()
        self.n_layers = n_layers

        self.embed = TaskSpecificEmbedding(num_embed, embed_dim)
        self.dec = DecoderLayer(input_size)

    def forward(self, trg, enc):
        """Embed ``trg`` and apply the decoder layer ``n_layers`` times using ``enc``."""
        # Task-specific query embedding.
        query = self.embed(trg)

        # The same ``dec`` instance is reused on every pass.
        for _ in range(self.n_layers):
            query = self.dec(query, enc)
        return query
    
    

class UniT(nn.Module):
    """Two-encoder, three-decoder multi-task model.

    Each of the two inputs is encoded separately, the encodings are
    concatenated along dimension 1, and three task-specific decoders followed
    by linear heads produce one output per task.
    """

    def __init__(self):
        super().__init__()
        # Encoder for input 1.
        self.enc1 = Encoder(4*28*28)

        # Encoder for input 2.
        self.enc2 = Encoder(4*28*28)

        # One decoder and one linear head per task.
        self.decoder1 = Decoder(784, num_embed=18, embed_dim=784)
        self.decoder2 = Decoder(784, num_embed=18, embed_dim=784)
        self.decoder3 = Decoder(784, num_embed=18, embed_dim=784)

        self.lin1 = nn.LazyLinear(3)
        self.lin2 = nn.LazyLinear(3)
        self.lin3 = nn.LazyLinear(3)

    def forward(self, src, trg):
        """Compute the three task outputs.

        Args:
            src: Sequence whose first two items are the inputs of ``enc1`` and
                ``enc2``.
            trg: Sequence whose first three items are the index tensors fed to
                the task-specific embeddings of the three decoders.

        Returns:
            List ``[out1, out2, out3]`` with the outputs of the three heads.
        """
        enc_out1 = self.enc1(src[0])
        enc_out2 = self.enc2(src[1])

        # Shared memory for all decoders.
        # NOTE(review): the encoders are built with feature size 4*28*28 while
        # the decoders use 784, so this memory presumably does not match the
        # shape the decoders' attention expects -- TODO confirm the intended
        # shapes.
        memory = torch.cat([enc_out1, enc_out2], dim=1)

        dec1 = self.decoder1(trg[0], memory)
        dec2 = self.decoder2(trg[1], memory)
        # Task 3 uses its own decoder (it was previously routed through decoder2).
        dec3 = self.decoder3(trg[2], memory)

        out1 = self.lin1(dec1)
        # Store the head output under ``out2`` so the returned list holds the
        # task-2 prediction instead of the raw output of ``enc2``.
        out2 = self.lin2(dec2)
        out3 = self.lin3(dec3)

        return [out1, out2, out3]

I understand that there is a mismatch between my target and my logit sizes. However, I’m not quite sure how to approach this. In my example the label consists of three classes.

I would be grateful for any guidance.

nn.CrossEntropyLoss expects a model output containing logits in the shape [batch_size, nb_classes, *additional_dims] and targets containing class indices in the range [0, nb_classes-1] in the shape [batch_size, *additional_dims].
Based on the error message it seems your model output has the shape [batch_size, nb_classes, 3] which then expects a target in the shape [batch_size, 3].
Could you explain the output shape in more detail and what the dimensions represent?

I noticed some problems with my shapes and changed the code accordingly. However, I still have a dimension mismatch when passing the encoder output as input to my decoder. The attention layer expects `key` and `value` to be 3-dimensional, but I only obtain a 2-dimensional tensor after passing the input through my encoder layers.

Regarding the inputs passed to the architecture: I assume image data with a height and width of 28 and 3 channels. In addition, I assume that I have 3 classes, specified via

y1 = torch.randint(0,3,(1000,1))

Thank you in advance.