I created a ResNet-like model for tabular data. Here is the model architecture:
class Model(nn.Module):
    """ResNet-style MLP for tabular data.

    Three heads (BatchNorm -> Dropout -> weight-normalized Linear [-> ReLU])
    are stacked. Before the second and third heads, the raw input ``x`` is
    concatenated with the previous head's output (a dense skip connection),
    so those heads take ``num_features + hidden_size`` input features.

    Args:
        num_features: Number of input features per sample (``x.shape[1]``).
        num_targets: Number of output values per sample.
        hidden_size: Width of the hidden layers.

    Note:
        ``BatchNorm1d`` raises in training mode for a batch of size 1; call
        ``model.eval()`` before predicting on single samples.
    """

    def __init__(self, num_features, num_targets, hidden_size):
        super().__init__()
        # Width of the tensors fed to head2/head3: raw features + previous head output.
        skip_size = num_features + hidden_size

        # The heads MUST be attributes of ``self``: only then does nn.Module
        # register them, so their parameters show up in .parameters(), are
        # moved by .to(device), appear in print(model), and are reachable
        # from forward(). (As plain locals they were silently discarded.)
        self.head1 = nn.Sequential(
            nn.BatchNorm1d(num_features),
            nn.Dropout(0.2),
            nn.utils.weight_norm(nn.Linear(num_features, hidden_size)),
            nn.ReLU(),
        )
        self.head2 = nn.Sequential(
            nn.BatchNorm1d(skip_size),
            nn.Dropout(0.25),
            nn.utils.weight_norm(nn.Linear(skip_size, hidden_size)),
            nn.ReLU(),
        )
        # No activation on the final head: it produces raw outputs
        # (logits/regression values) for the loss function to consume.
        self.head3 = nn.Sequential(
            nn.BatchNorm1d(skip_size),
            nn.Dropout(0.3),
            nn.utils.weight_norm(nn.Linear(skip_size, num_targets)),
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor; assumed shape ``(batch, num_features)``.

        Returns:
            Tensor of shape ``(batch, num_targets)``.
        """
        head1_out = self.head1(x)
        concat_1 = torch.cat((x, head1_out), dim=1)
        head2_out = self.head2(concat_1)
        concat_2 = torch.cat((x, head2_out), dim=1)
        return self.head3(concat_2)
When I initialize it, it doesn't raise any error, but it doesn't show any output either.
Any help or suggestions would be appreciated.
Thanks.