I am working on a multi-output (i.e. > 1 output target), multi-class (i.e. > 1 class) classification problem (I believe this is also called a multi-task problem).
For example, my train_feature_data is of shape (4, 6) (i.e. 4 rows/examples and 6 columns/features), and my train_target_data is of shape (4, 3) (i.e. 4 rows/examples and 3 columns/targets). Each target takes one of three classes (-1, 0, 1).
I define an example model architecture (and data) for this problem like so:
import pandas as pd
from torch import nn
from logging import log
import torch
feature_data = {
    'A': [1, 2, 3, 4],
    'B': [5, 6, 7, 8],
    'C': [9, 10, 11, 12],
    'D': [13, 14, 15, 16],
    'E': [17, 18, 19, 20],
    'F': [21, 22, 23, 24]
}
target_data = {
    'Col1': [1, -1, 0, 1],
    'Col2': [-1, 0, 1, -1],
    'Col3': [-1, 0, 1, 1]
}
# Create the DataFrame
train_feature_data = pd.DataFrame(feature_data)
train_target_data = pd.DataFrame(target_data)
device = "cuda" if torch.cuda.is_available() else "cpu"
# create the model
class MyModel(nn.Module):
    def __init__(self, inputs=6, l1=12, outputs=3):
        super().__init__()
        self.sequence = nn.Sequential(
            nn.Linear(inputs, l1),
            nn.Linear(l1, outputs),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.sequence(x)
        return x
x_train = torch.tensor(train_feature_data.to_numpy()).type(torch.float)
model = MyModel(inputs = 6, l1 = 12, outputs = 3).to(device)
model(x_train.to(device=device))
When I pass my train data into the model (i.e. when I call model(x_train.to(device=device))), I get back a tensor of shape (4, 3).
Following this resource, my expectation was that I would get something of shape (4, 3, 3). Here the 4 is the number of examples in my features and targets. The middle 3 (the second axis) holds the logits for each example, one per class (or, because I have a softmax, the predicted probabilities). The last 3 (the third axis) is the number of outputs/columns in my train_target_data.
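To make the shape I am expecting concrete, here is a minimal sketch of one way to get a (4, 3, 3) output. It is only an assumption about how this could be done, not a confirmed fix: the final layer emits n_classes * n_targets values, which are reshaped to (batch, classes, targets). The names MultiOutputModel, n_classes and n_targets are hypothetical. I dropped the softmax from the model and shifted the labels from (-1, 0, 1) to (0, 1, 2), because nn.CrossEntropyLoss takes raw logits and class indices starting at 0.

```python
import torch
from torch import nn


class MultiOutputModel(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, inputs=6, l1=12, n_classes=3, n_targets=3):
        super().__init__()
        self.n_classes = n_classes
        self.n_targets = n_targets
        self.sequence = nn.Sequential(
            nn.Linear(inputs, l1),
            # one logit per (class, target) pair
            nn.Linear(l1, n_classes * n_targets),
        )

    def forward(self, x):
        logits = self.sequence(x)
        # (batch, classes * targets) -> (batch, classes, targets)
        return logits.reshape(-1, self.n_classes, self.n_targets)


# Same values as the feature columns A..F above, shape (4, 6)
x = torch.arange(1, 25, dtype=torch.float).reshape(6, 4).T

# Same values as Col1..Col3 above, shifted by +1 so classes are 0, 1, 2
targets = torch.tensor([[1, -1, -1],
                        [-1, 0, 0],
                        [0, 1, 1],
                        [1, -1, 1]]) + 1

model = MultiOutputModel()
logits = model(x)                 # shape (4, 3, 3)

# CrossEntropyLoss accepts input (N, C, d1) with target (N, d1)
loss = nn.CrossEntropyLoss()(logits, targets)

# Probabilities per target: softmax over the class axis (dim=1)
probs = logits.softmax(dim=1)
```

With this layout, probs[:, :, k] would hold the class probabilities for the k-th target column, and probs.argmax(dim=1) - 1 would map the predictions back to the original (-1, 0, 1) labels.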
Can someone please provide some guidance on what I’m doing incorrectly here (if my approach is wrong) and how to go about fixing it? Thanks.