PyTorch Siamese model using LSTM

I am building a Siamese model using an LSTM. I have trained and tested the model, but I can't run inference on a single sample.

Here's the model:

class SiameseLstm(nn.Module):
    def __init__(self):
        super(SiameseLstm, self).__init__()
        self.embedding_model = nn.Embedding(model_vocab_size, embedding_dim)
        self.embedding_model.weight.data.copy_(embedding_matrix_model)
        self.embedding_student = nn.Embedding(student_vocab_size, embedding_dim)
        self.embedding_student.weight.data.copy_(embedding_matrix_student)
        # bidirectional LSTM: each time step outputs 2 * 128 = 256 features
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=128, num_layers=2,
                            batch_first=True, bidirectional=True)
        self.flat = nn.Flatten()
        self.dropout = nn.Dropout(0.1)
        self.dense = nn.Linear(544, 1)
        self.activation = nn.Sigmoid()
        
    def forward_model(self, input):
        embedded = self.embedding_model(input)
        out, (h, c) = self.lstm(embedded)
        #hidden = torch.cat((h[-2,:,:], h[-1,:,:]), dim = 1)
        output = out[:, -1, :]        # last time step: (batch, 256)
        output = self.flat(output)    # no-op on an already 2-D tensor
        return output
    
    def forward_student(self, input):
        embedded = self.embedding_student(input)
        out, (h, c) = self.lstm(embedded)
        #hidden = torch.cat((h[-2,:,:], h[-1,:,:]), dim = 1)
        output = out[:, -1, :]        # last time step: (batch, 256)
        output = self.flat(output)    # no-op on an already 2-D tensor
        return output

    def forward(self, inp1, inp2):
        out1 = self.forward_model(inp1)      # (batch, 256)
        out2 = self.forward_student(inp2)    # (batch, 256)
        x3 = torch.subtract(out1, out2)
        x3 = torch.multiply(x3, x3)          # squared difference: (batch, 256)
        x1_ = torch.multiply(out1, out1)
        x2_ = torch.multiply(out2, out2)
        x4 = torch.subtract(x1_, x2_)        # difference of squares: (batch, 256)
        x5 = torch.cdist(out1, out2)         # pairwise distances: (batch, batch)
        merged = torch.cat((x5, x4, x3), dim=-1)   # (batch, batch + 512)
        merged = self.dropout(merged)
        merged = self.dense(merged)          # expects 544 input features
        merged = self.activation(merged)
        return merged
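
For what it's worth, a quick shape check with dummy tensors shows where the dense layer's 544 input features come from, and why that number changes with the batch size (this is a sketch; the batch size of 32 is inferred from the 544 in the error below):

import torch

# stand-ins for the two LSTM outputs: (batch, 2 * hidden_size) = (batch, 256)
out1 = torch.randn(32, 256)
out2 = torch.randn(32, 256)

x5 = torch.cdist(out1, out2)                # (32, 32) -- one column per sample in the batch
x4 = out1 * out1 - out2 * out2              # (32, 256)
x3 = (out1 - out2) * (out1 - out2)          # (32, 256)
merged = torch.cat((x5, x4, x3), dim=-1)
print(merged.shape)                         # torch.Size([32, 544])

With a batch of one, torch.cdist contributes a single column, so merged has 1 + 256 + 256 = 513 features.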
        

And here is the training loop:

Model.train()
for i, (inp1, inp2, label) in enumerate(train_dataloader):
    inp1 = inp1.cuda()
    inp2 = inp2.cuda()
    label = label.cuda()
    optimizer.zero_grad()
    output = Model(inp1, inp2)
    #print('output', output)
    #print('label', label)
    loss = criterion(output, label)
    acc = calculate_acc(output, label)
    TrainLoss += loss.item()
    TrainAcc += acc.item()
    loss.backward()
    optimizer.step()
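
(The loop is abridged; the surrounding setup looks roughly like this. This is a sketch where BCELoss, Adam, and the accuracy helper are stand-ins, not necessarily the exact originals:)

# sketch of the setup the loop assumes; the loss/optimizer choices here are
# placeholders -- BCELoss matches the sigmoid output, but the original may differ
Model = SiameseLstm().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(Model.parameters(), lr=1e-3)

def calculate_acc(output, label):
    # fraction of predictions on the correct side of 0.5 (illustrative helper)
    return ((output > 0.5) == label.bool()).float().mean()

TrainLoss, TrainAcc = 0.0, 0.0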

And this is the error I get at inference time:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x513 and 544x1)
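
A single-sample call like the following is what triggers it (a minimal sketch; seq_len is a placeholder for my actual padded sequence length):

Model.eval()
with torch.no_grad():
    inp1 = torch.randint(0, model_vocab_size, (1, seq_len)).cuda()    # batch of one
    inp2 = torch.randint(0, student_vocab_size, (1, seq_len)).cuda()
    pred = Model(inp1, inp2)    # fails in self.dense: merged is (1, 1 + 512) = (1, 513), not (1, 544)

What is the right way to fix the model so it works for any batch size, including a single sample?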