Hello, I’m developing a model that fuses two tensors with a single combining operation (for example: concatenation, addition, subtraction, maximum, minimum, or multiplication).
However, I currently have to train the model once per operation to find out which one gives the lowest loss. Is there a way to compare these operations dynamically at runtime and select the best one?
class Model(nn.Module):
    """Fuse the outputs of two sub-models with a selectable operation, then classify.

    Pipeline: ``model_one`` / ``model_two`` -> combine -> Linear projection to
    ``hidden_features`` -> BatchNorm -> Dropout -> Linear -> ReLU -> Linear
    (``classes`` outputs).

    ``op`` selects how the two sub-model outputs are combined. It may be the
    original integer code or the operation name:

        1 / "cat"   concatenate along dim 1
        2 / "add"   one + two
        3 / "sub"   one - two
        4 / "-sub"  two - one
        5 / "max"   element-wise maximum
        6 / "min"   element-wise minimum
        7 / "mul"   element-wise product

    ``op="auto"`` compares the operations inside a single training run: every
    operation is computed and the projected results are mixed with
    ``softmax(op_logits)``. ``op_logits`` is an ordinary learnable parameter,
    so gradient descent can shift weight toward the operations that lower the
    loss. Afterwards :meth:`best_op` names the highest-weighted operation, which
    can then be passed as ``op`` to train a final single-operation model (the
    mixture weights reflect the training objective, so confirm the choice on
    validation data).

    Args:
        model_one: First sub-model, applied to ``data_model_one``.
        model_two: Second sub-model, applied to ``data_model_two``.
        in_features_one: Feature width of ``model_one``'s output; the output is
            expected to be ``(batch, in_features_one)``.
        in_features_two: Feature width of ``model_two``'s output. Element-wise
            ops (everything except "cat") require it to equal
            ``in_features_one``.
        classes: Number of output logits.
        hidden_features: Width of the fused representation (previously a
            hard-coded 768).
    """

    # Integer codes accepted by ``forward``; this order also fixes the layout
    # of ``op_logits``.
    OP_CODES = {1: "cat", 2: "add", 3: "sub", 4: "-sub", 5: "max", 6: "min", 7: "mul"}
    OP_NAMES = tuple(OP_CODES.values())

    # Combination functions keyed by operation name.
    _COMBINE = {
        "cat": lambda one, two: torch.cat([one, two], dim=1),
        "add": torch.add,
        "sub": torch.sub,
        "-sub": lambda one, two: torch.sub(two, one),
        "max": torch.maximum,
        "min": torch.minimum,
        "mul": torch.mul,
    }

    def __init__(
        self,
        model_one: nn.Module,
        model_two: nn.Module,
        in_features_one: int,
        in_features_two: int,
        classes: int,
        hidden_features: int = 768,
    ):
        super().__init__()
        self.classes = classes
        self.model_one = model_one
        self.model_two = model_two
        # Normalize
        # NOTE(review): bn_model_one / bn_model_two are created but not applied
        # in forward (unchanged); kept so existing state_dicts still load.
        self.bn_model_one = nn.BatchNorm1d(in_features_one)
        self.bn_model_two = nn.BatchNorm1d(in_features_two)
        self.bn_fused = nn.BatchNorm1d(hidden_features)
        # Projections to hidden_features. forward previously called an
        # undefined ``self.fusion``. Concatenation widens the features to
        # one + two, so it needs its own projection; element-wise ops keep
        # in_features_one.
        self.fusion = nn.Linear(in_features_one, hidden_features)
        self.fusion_cat = nn.Linear(in_features_one + in_features_two, hidden_features)
        # One learnable score per operation; only used by op="auto".
        self.op_logits = nn.Parameter(torch.zeros(len(self.OP_NAMES)))
        # Layers and FC
        self.layer_1 = nn.Linear(hidden_features, hidden_features)
        self.fc = nn.Linear(hidden_features, self.classes)
        # Dropout / ReLU
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()

    def _resolve_op(self, op) -> str:
        """Map an integer code, 0-d tensor, or name to an operation name.

        Raises:
            ValueError: if ``op`` does not name a known operation.
        """
        if isinstance(op, torch.Tensor):
            op = op.item()
        name = op if isinstance(op, str) else self.OP_CODES.get(op)
        if name not in self._COMBINE:
            raise ValueError(
                f"unknown op {op!r}; expected one of {sorted(self.OP_CODES)}, "
                f"{list(self.OP_NAMES)} or 'auto'"
            )
        return name

    def _fuse(self, name: str, one: torch.Tensor, two: torch.Tensor) -> torch.Tensor:
        """Combine ``one`` and ``two`` with operation ``name`` and project to hidden_features."""
        combined = self._COMBINE[name](one, two)
        projection = self.fusion_cat if name == "cat" else self.fusion
        return projection(combined)

    def _fuse_mixture(self, one: torch.Tensor, two: torch.Tensor) -> torch.Tensor:
        """Return the softmax(op_logits)-weighted sum of every operation's projection."""
        weights = torch.softmax(self.op_logits, dim=0)
        candidates = torch.stack([self._fuse(name, one, two) for name in self.OP_NAMES])
        # Reshape weights to (K, 1, ..., 1) so they broadcast over each candidate.
        weights = weights.view(-1, *([1] * (candidates.dim() - 1)))
        return (weights * candidates).sum(dim=0)

    def best_op(self) -> str:
        """Return the name of the operation with the largest learned ``op_logits`` score."""
        return self.OP_NAMES[int(torch.argmax(self.op_logits))]

    #######################################
    # Forward
    #######################################
    def forward(self, data_model_one, data_model_two, op) -> torch.Tensor:
        """Run both sub-models, fuse their outputs with ``op``, and return class logits.

        Args:
            data_model_one: Input for ``model_one``.
            data_model_two: Input for ``model_two``.
            op: Integer code 1-7, operation name, or "auto" (see class docstring).

        Raises:
            ValueError: if ``op`` is not a known operation.
        """
        out_one = self.model_one(data_model_one)
        out_two = self.model_two(data_model_two)
        if isinstance(op, str) and op == "auto":
            projected = self._fuse_mixture(out_one, out_two)
        else:
            projected = self._fuse(self._resolve_op(op), out_one, out_two)
        fused = self.dropout(self.bn_fused(projected))
        # ReLU between the two Linear layers (it was defined but unused before);
        # without it, layer_1 and fc compose to a single affine map.
        fused = self.relu(self.layer_1(fused))
        return self.fc(fused)