class wrapper(nn.Module):
    """Apply one ``var`` branch per input and concatenate the normalized outputs.

    One ``var`` sub-network is built for every entry of ``input_size``.
    Branches are registered under the names ``m1``, ``m2``, ... so that
    attribute access (``self.m1``) and ``state_dict`` keys match the
    hand-written five-branch layout.
    """

    def __init__(self, input_size, combined_input_size, hidden_size, num_classes, number_of_vars):
        """Build one ``var`` branch for each entry of ``input_size``.

        Args:
            input_size: Sequence whose i-th entry is forwarded as
                ``input_size`` to the i-th ``var`` branch.
            combined_input_size: Not used by this class; kept so existing
                callers do not need to change.
            hidden_size: Forwarded unchanged to every ``var`` branch.
            num_classes: Forwarded unchanged to every ``var`` branch.
            number_of_vars: Not used by this class; kept so existing
                callers do not need to change.
        """
        super().__init__()
        # Names of the registered branches, in the order they consume inputs.
        self._branch_names = []
        for position, size in enumerate(input_size, start=1):
            branch_name = f"m{position}"
            # add_module registers the branch so its parameters are tracked
            # by the parent module while keeping the original m1..mN names.
            self.add_module(
                branch_name,
                var(input_size=size, hidden_size=hidden_size, num_classes=num_classes, p=0.0),
            )
            self._branch_names.append(branch_name)

    def forward(self, var_list):
        """Run branch ``i`` on ``var_list[i]`` and concatenate along dim 1.

        Args:
            var_list: Indexable collection holding one input per branch, in
                branch order. Entries beyond the number of branches are
                ignored.

        Returns:
            The branch outputs, each L2-normalized along dim 1 and then
            concatenated along dim 1.

        Raises:
            IndexError: If ``var_list`` has fewer entries than there are
                branches.
        """
        # Each branch reads the input at its own index; the previous
        # hand-written version fed var_list[3] to both m4 and m5.
        normalized = [
            F.normalize(getattr(self, branch_name)(var_list[index]), p=2, dim=1)
            for index, branch_name in enumerate(self._branch_names)
        ]
        return torch.cat(normalized, dim=1)
How can I make all these calls inside a loop? Or is there a better way to do this?