I am training two models end-to-end and want to fuse the last layers of both models using a dot product between tensors.
I have tried concatenation, element-wise addition, and matrix multiplication so far.
But I’m having trouble with the dot product!
Here’s my model: each branch is a simple MLP with three hidden layers, and an ensemble module fuses the two branches' final hidden outputs by concatenation. I’d like to replace the concatenation operation with a dot product.
import torch
import torch.nn as nn
import torch.nn.functional as F


class MulticlassClassification(nn.Module):
    """MLP classifier: num_feature -> 500 -> 500 -> 256 -> num_class.

    ReLU follows each of the three hidden linear layers; dropout (p=0.3)
    is applied after the second and third hidden layers.
    """

    def __init__(self, num_feature, num_class=3):
        super().__init__()
        self.layer_1 = nn.Linear(num_feature, 500)
        self.layer_2 = nn.Linear(500, 500)
        self.layer_3 = nn.Linear(500, 256)
        self.layer_out = nn.Linear(256, num_class)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return the output of ``layer_out`` for input ``x``."""
        x = self.relu(self.layer_1(x))
        x = self.dropout(self.relu(self.layer_2(x)))
        x = self.dropout(self.relu(self.layer_3(x)))
        return self.layer_out(x)


model_EHR = MulticlassClassification(3)
model_G = MulticlassClassification(79)


class MyEnsemble(nn.Module):
    """Fuse two sub-models by concatenating their final hidden outputs.

    NOTE: the constructor replaces ``layer_out`` on BOTH passed-in models
    with ``nn.Identity`` (the arguments are mutated in place), so each
    sub-model then emits the output of its last hidden layer.
    """

    def __init__(self, model_EHR, model_G, nb_classes=3):
        super().__init__()
        self.model_EHR = model_EHR
        self.model_G = model_G
        # Remove last linear layer
        self.model_EHR.layer_out = nn.Identity()
        self.model_G.layer_out = nn.Identity()
        # Create new classifier; 512 = 256 + 256, the concatenated widths of
        # the two MulticlassClassification branches' last hidden layers.
        self.layer_out = nn.Linear(512, nb_classes)

    def forward(self, x1, x2):
        """Run both branches, concatenate along dim 1, ReLU, then classify."""
        x1 = self.model_EHR(x1)
        # Flatten to (batch, features); the author notes this reshape is the
        # one needed when fusing by multiplication instead of concatenation.
        x1 = x1.view(x1.size(0), -1)
        x2 = self.model_G(x2)
        x2 = x2.view(x2.size(0), -1)
        x = torch.cat((x1, x2), dim=1)
        return self.layer_out(F.relu(x))


def model() -> MyEnsemble:
    """Build the ensemble from the module-level ``model_EHR`` and ``model_G``."""
    return MyEnsemble(model_EHR, model_G)


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

if __name__ == "__main__":
    # torchsummary is only needed for the printed report, so it is imported
    # here rather than at module import time.
    from torchsummary import summary

    # Use a distinct name so the ``model`` factory is not shadowed.
    ensemble = model()
    print(ensemble)
    ensemble.to(device=DEVICE, dtype=torch.float)
    # dimension of first model (1, 3), dimension of second model (1, 79)
    summary(ensemble, [(1, 3), (1, 79)])
Output of `summary(model, ...)` (torchsummary):
I tried this approach, which worked for randomly generated tensors but not for my model.
import torch


def dot_fusion(x1, x2):
    """Return the per-sample dot product of two (batch, features) tensors.

    The result has shape (batch, 1, 1). Transposing ``x2`` with
    ``torch.transpose(x2, 0, 1)`` swaps the batch and feature axes, which is
    only a dot product when batch == 1; for larger batches it yields a
    (1, batch, batch) matrix of all pairwise products. Adding the singleton
    axes explicitly keeps the batch dimension first for ``torch.bmm``.
    """
    # (batch, 1, features) @ (batch, features, 1) -> (batch, 1, 1)
    return torch.bmm(x1.unsqueeze(1), x2.unsqueeze(2))


x1 = torch.randn([1, 256])
x2 = torch.randn([1, 256])
fusion = dot_fusion(x1, x2)
Any suggestions, please?