Automatically Select Arithmetic Operations Between Tensors

Hello, I’m developing a model that combines two tensors using an arithmetic operation (for example addition, subtraction, element-wise maximum or minimum, or multiplication).

Currently I have to train with each operation separately to find out which one gives the lowest loss. Is there a way to compare these operations dynamically at runtime and select the best one?

class Model(nn.Module):
    """Fuse the outputs of two sub-models with a selectable operation, then classify.

    ``forward`` runs both sub-models, combines their outputs with the
    operation chosen by ``op`` (see ``OPS``), projects the result to
    ``hidden`` features, and maps it to ``classes`` outputs.
    """

    # Operation codes accepted by ``forward``'s ``op`` argument.
    OPS = {1: "cat", 2: "add", 3: "sub", 4: "-sub", 5: "max", 6: "min", 7: "mul"}

    def __init__(self, model_one: nn.Module,
                 model_two: nn.Module,
                 in_features_one: int,
                 in_features_two: int,
                 classes: int,
                 hidden: int = 768):
        """Build the fusion head.

        Args:
            model_one: Sub-model applied to ``data_model_one``.
            model_two: Sub-model applied to ``data_model_two``.
            in_features_one: Feature size of ``model_one``'s output.
            in_features_two: Feature size of ``model_two``'s output. The
                element-wise ops (codes 2-7) need this to equal (or broadcast
                against) ``in_features_one``.
            classes: Number of outputs of the final linear layer.
            hidden: Width of the fused representation (previously hard-coded
                to 768, which remains the default).
        """
        super().__init__()

        self.classes = classes
        self.model_one = model_one
        self.model_two = model_two

        # Normalize.
        # NOTE(review): bn_model_one / bn_model_two are not applied in forward;
        # they are kept so existing state_dicts still load.
        self.bn_model_one = nn.BatchNorm1d(in_features_one)
        self.bn_model_two = nn.BatchNorm1d(in_features_two)
        self.bn_fused = nn.BatchNorm1d(hidden)

        # Projections onto the shared hidden size. The original forward called
        # an undefined ``self.fusion``. Concatenation widens the features to
        # in_features_one + in_features_two, so it needs its own projection.
        self.fusion = nn.Linear(in_features_one, hidden)
        self.fusion_cat = nn.Linear(in_features_one + in_features_two, hidden)

        # Layers and FC
        self.layer_1 = nn.Linear(hidden, hidden)
        self.fc = nn.Linear(hidden, self.classes)

        # Dropout / ReLU
        # NOTE(review): relu is never used, so layer_1 -> fc is a purely
        # linear stack; confirm whether an activation was intended.
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()

    @classmethod
    def _combine(cls, out_one, out_two, op):
        """Combine two sub-model outputs with the operation coded by ``op``.

        Raises:
            ValueError: if ``op`` is not a key of ``OPS``.
        """
        if op == 1:  # cat along dim 1 (assumes (batch, features) outputs)
            return torch.cat([out_one, out_two], dim=1)
        if op == 2:  # sum
            return torch.add(out_one, out_two)
        if op == 3:  # one - two
            return torch.sub(out_one, out_two)
        if op == 4:  # two - one
            return torch.sub(out_two, out_one)
        if op == 5:  # element-wise max
            return torch.max(out_two, out_one)
        if op == 6:  # element-wise min
            return torch.min(out_two, out_one)
        if op == 7:  # element-wise product
            return torch.mul(out_two, out_one)
        # Previously an unknown op fell through and hit an UnboundLocalError.
        raise ValueError(f"unknown fusion op {op!r}; expected one of {sorted(cls.OPS)}")

    def forward(self, data_model_one, data_model_two, op):
        """Run both sub-models, fuse their outputs with ``op`` and classify.

        Args:
            data_model_one: Input passed to ``model_one``.
            data_model_two: Input passed to ``model_two``.
            op: Integer operation code, a key of ``OPS`` (1..7).

        Returns:
            The output of ``self.fc`` (last dimension ``classes``).

        Raises:
            ValueError: if ``op`` is not a key of ``OPS``.
        """
        out_one = self.model_one(data_model_one)
        out_two = self.model_two(data_model_two)

        combined = self._combine(out_one, out_two, op)

        # Concatenation has a different width, so it uses its own projection.
        projection = self.fusion_cat if op == 1 else self.fusion
        fused = self.dropout(self.bn_fused(projection(combined)))
        fused = self.layer_1(fused)
        fused = self.fc(fused)

        return fused

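PyTorch is an eager-mode framework, which means you can use ordinary Python branches and for loops to choose behavior at runtime. For example, you can loop over the candidate op values, train and evaluate the model with each one, and keep the op with the lowest validation loss. I didn’t try running your code, but this approach should in principle work.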