Trouble setting up a multi-target learning problem where the targets have variable length

I’m having trouble setting up a multi-target learning problem where the targets have variable length. I think the problem is in how I’ve set up my data, but it could also be in the model.

Each target has a binary classification label and a regression value, and different examples may have different numbers of targets. Below is my Dataset code, which takes a pandas DataFrame and returns a 1D tensor for each training example and a dictionary of 1D tensors for the targets: the regression values (called locs) and the classes (called labels). I’ve also printed out an example:

import torch
from torch.utils.data import Dataset, DataLoader

dep_vars = ['target_locs', 'target_labels']
ind_vars = [col for col in train_df.columns if col not in dep_vars]

class llni_Dataset(Dataset):
    def __init__(self, df):
        self.X = df[ind_vars]
        self.y = df[dep_vars]

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Input features: one fixed-length 1D float tensor per row.
        # .iloc is positional, so this also works when df's index is not 0..n-1
        X = torch.as_tensor(self.X.iloc[idx].values, dtype=torch.float32)

        # Targets: variable-length regression values and binary labels
        target = {}
        target['locs'] = torch.as_tensor(self.y['target_locs'].iloc[idx], dtype=torch.float32)
        target['labels'] = torch.as_tensor(
            [1 if 'road' in targ else 0 for targ in self.y['target_labels'].iloc[idx]],
            dtype=torch.int64)

        return X, target

train_ds = llni_Dataset(train_df)

print(train_ds[0])

which returns:

(tensor([-1.2727e+01,  1.0000e+01,  1.2515e-02,  1.7460e-02,  3.4232e-02,
          4.4352e-02,  4.9399e-02,  6.2449e-02,  7.0045e-02,  0.0000e+00,
          3.6364e+00,  2.0909e+01,  2.4545e+01, -1.0000e+02, -1.0000e+02,
         -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
         -1.0000e+02, -1.0000e+02,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  2.5490e-04,  7.6469e-04,
          1.1470e-03,  7.0096e-03,  1.1725e-02,  1.7460e-02,  1.4274e-02,
          7.5194e-03,  1.7843e-03,  1.7843e-03,  5.4802e-03,  2.3450e-02,
          5.6714e-02,  7.6978e-02,  5.5057e-02,  3.5558e-02,  4.5881e-02,
          7.0096e-02,  6.2449e-02,  3.2244e-02,  9.1762e-03,  2.1666e-03,
          6.3724e-04,  1.2745e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  2.5490e-04,  6.3724e-04,
          4.8430e-03,  1.4402e-02,  4.4224e-02,  6.7930e-02,  6.8567e-02,
          5.2763e-02,  4.7156e-02,  6.7930e-02,  6.9841e-02,  4.4352e-02,
          1.8607e-02,  1.2108e-02,  1.4529e-02,  1.4147e-02,  9.9409e-03,
          5.2254e-03,  2.1666e-03,  5.0979e-04,  0.0000e+00,  0.0000e+00,
          1.2745e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00]),
 {'locs': tensor([ 6.3051, -4.5587,  4.8963,  1.4027, -2.2576]),
  'labels': tensor([1, 1, 0, 0, 0])})

Next is probably where the error is. The only way I could find to batch targets of different lengths is a collating function like the one below. Please let me know if there’s a better way to accomplish this. Using this collating function, I put the Dataset into a DataLoader and printed an example batch, b (which we’ll use again below):

def collate_fn(batch):
    # Transposes [(X, target), ...] into ((X, X, ...), (target, target, ...))
    return tuple(zip(*batch))

train_dl = DataLoader(train_ds, batch_size = 4, shuffle = True, collate_fn = collate_fn)

for b in train_dl:
    print(b)
    break

which returns:

((tensor([-7.2727e+00,  9.0909e+00,  1.1855e-02,  2.9259e-02,  4.3675e-02,
         5.1220e-02,  5.2019e-02,  5.3974e-02,  8.8631e-02, -9.0909e-01,
         2.0000e+01,  2.2727e+01, -1.0000e+02, -1.0000e+02, -1.0000e+02,
        -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
        -1.0000e+02, -1.0000e+02,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  2.7538e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  6.8845e-04,
         9.6383e-04,  5.0945e-03,  2.0929e-02,  5.1633e-02,  9.0600e-02,
         1.1043e-01,  7.0910e-02,  5.2460e-02,  5.1909e-02,  4.8329e-02,
         3.1531e-02,  1.1979e-02,  3.1669e-03,  6.8845e-04,  6.8845e-04,
         2.7538e-04,  0.0000e+00,  1.3769e-04,  1.3769e-04,  6.8845e-04,
         4.1307e-04,  1.3769e-03,  4.4061e-03,  1.0877e-02,  1.9965e-02,
         3.0292e-02,  4.2133e-02,  5.3561e-02,  5.3974e-02,  5.1220e-02,
         5.3974e-02,  6.5815e-02,  6.2786e-02,  4.4061e-02,  2.8226e-02,
         1.1841e-02,  6.4714e-03,  3.4422e-03,  6.8845e-04,  2.7538e-04,
         1.3769e-04,  1.3769e-04,  1.3769e-04,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  1.3769e-04,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3769e-04,  0.0000e+00,
         0.0000e+00]), tensor([-1.0000e+01,  9.0909e+00,  1.3721e-02,  4.6678e-02,  8.1144e-02,
         9.3003e-02,  1.0189e-01,  1.1823e-01,  1.7623e-01, -1.8182e+00,
         1.8182e+00, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
        -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
        -1.0000e+02, -1.0000e+02,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  1.4145e-03,  2.3575e-03,  4.9507e-03,  5.8937e-03,
         1.6267e-02,  3.3948e-02,  7.3789e-02,  1.3131e-01,  1.9426e-01,
         1.6149e-01,  1.0727e-01,  9.3828e-02,  1.1387e-01,  9.2177e-02,
         5.0922e-02,  1.0609e-02,  3.5362e-03,  1.1787e-03,  4.7150e-04,
         4.7150e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00]), tensor([-1.1818e+01,  6.3636e+00,  2.2130e-02,  4.6158e-02,  7.5597e-02,
         8.8660e-02,  1.0229e-01,  1.3275e-01,  1.6093e-01, -2.7273e+00,
        -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
        -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
        -1.0000e+02, -1.0000e+02,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  2.3580e-04,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  4.7160e-04,
         9.4319e-04,  1.8864e-03,  7.0740e-03,  2.1222e-02,  4.0322e-02,
         8.1822e-02,  1.3558e-01,  1.5775e-01,  1.6482e-01,  1.3181e-01,
         1.0682e-01,  9.5498e-02,  7.1447e-02,  4.8103e-02,  2.2872e-02,
         7.7814e-03,  2.3580e-03,  9.4319e-04,  0.0000e+00,  2.3580e-04,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00]), tensor([-1.0000e+01,  2.1818e+01,  1.4368e-02,  3.1920e-02,  4.4716e-02,
         5.0020e-02,  5.5548e-02,  6.2078e-02,  8.0558e-02, -1.8182e+00,
         1.8182e+00,  1.0909e+01, -1.0000e+02, -1.0000e+02, -1.0000e+02,
        -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
        -1.0000e+02, -1.0000e+02,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  1.1190e-04,  2.2380e-04,  6.7141e-04,  2.0142e-03,
         7.6094e-03,  1.7345e-02,  3.8159e-02,  4.7670e-02,  5.4273e-02,
         4.6663e-02,  6.3784e-02,  7.8220e-02,  8.3815e-02,  5.6958e-02,
         3.1668e-02,  1.3316e-02,  4.4761e-03,  4.0285e-03,  6.6022e-03,
         1.4212e-02,  2.9878e-02,  5.6734e-02,  8.0682e-02,  7.7660e-02,
         6.9268e-02,  5.6399e-02,  5.2370e-02,  4.3418e-02,  3.2675e-02,
         1.7345e-02,  7.9451e-03,  2.3499e-03,  6.7141e-04,  6.7141e-04,
         0.0000e+00,  1.1190e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00])), ({'locs': tensor([ 4.9503, -6.5299,  3.3085, -0.1613, -3.5986]), 'labels': tensor([1, 1, 0, 0, 0])}, {'locs': tensor([ 5.3896, -6.0695,  3.9064,  0.3867, -3.2428]), 'labels': tensor([1, 1, 0, 0, 0])}, {'locs': tensor([ 3.1086, -9.3537,  0.5882, -2.7026, -6.3033]), 'labels': tensor([1, 1, 0, 0, 0])}, {'locs': tensor([ 5.0380, -7.7921,  2.1711, -1.4178, -4.9330,  7.1524, 10.0885, 13.6362,
        17.3264]), 'labels': tensor([1, 1, 0, 0, 0, 1, 0, 0, 0])}))

I then tried to put together (what I was hoping would be) a very simple fully connected network. NOTE: I gave layers lin4_1 and lin4_2 an output size of 13 because 13 is the maximum target length I expect. Again, I’m not sure if this is the correct approach.

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, input_size):
        super(Net, self).__init__()
        self.lin1 = nn.Linear(input_size, 200)
        self.lin2 = nn.Linear(200, 100)
        self.lin3 = nn.Linear(100, 50)
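        # Output heads sized to 13, the maximum expected number of targets
        self.lin4_1 = nn.Linear(50, 13)  # label logits
        self.lin4_2 = nn.Linear(50, 13)  # regression values (locs)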
        
        self.bn1 = nn.BatchNorm1d(200)
        self.bn2 = nn.BatchNorm1d(100)
        self.bn3 = nn.BatchNorm1d(50)
        
        self.drops = nn.Dropout(0.2)
        
    def forward(self, x):
        x = self.bn1(F.relu(self.lin1(x)))
        x = self.drops(x)
        x = self.bn2(F.relu(self.lin2(x)))
        x = self.drops(x)
        x = self.bn3(F.relu(self.lin3(x)))
        x = self.drops(x)
        
        labels = self.lin4_1(x)
        locs = self.lin4_2(x)
    
        return locs, labels

net = Net(len(ind_vars))

But when I try to run a batch of training examples through the network, I get an error saying that a tuple has no attribute ‘dim’, which I take to mean there is a problem with how I collate the batch. Running

net(b[0])

gives:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_6716/498270831.py in <module>
----> 1 net(b[0])

~/miniconda3/envs/emp_path/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/tmp/ipykernel_6716/4134123215.py in forward(self, x)
     15 
     16     def forward(self, x):
---> 17         x = self.bn1(F.relu(self.lin1(x)))
     18         x = self.drops(x)
     19         x = self.bn2(F.relu(self.lin2(x)))

~/miniconda3/envs/emp_path/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/miniconda3/envs/emp_path/lib/python3.9/site-packages/torch/nn/modules/linear.py in forward(self, input)
     91 
     92     def forward(self, input: Tensor) -> Tensor:
---> 93         return F.linear(input, self.weight, self.bias)
     94 
     95     def extra_repr(self) -> str:

~/miniconda3/envs/emp_path/lib/python3.9/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1686         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
   1687             return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
-> 1688     if input.dim() == 2 and bias is not None:
   1689         # fused op is marginally faster
   1690         ret = torch.addmm(bias, input, weight.t())

AttributeError: 'tuple' object has no attribute 'dim'
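The traceback follows from the collate function: `tuple(zip(*batch))` returns `(all Xs, all targets)`, so `b[0]` is a tuple of four 1D tensors, while `nn.Linear` expects a single tensor. Since the inputs all have the same length, they can be stacked; only the targets need special handling. A minimal sketch, where `ToyDataset` is a hypothetical stand-in for `llni_Dataset` with 8 features:

```python
import torch
from torch.utils.data import DataLoader, Dataset

def collate_fn(batch):
    # batch is a list of (X, target) pairs as returned by __getitem__
    xs, targets = zip(*batch)
    # Every X has the same number of features, so stack them into one
    # (batch_size, num_features) tensor that nn.Linear can consume
    X = torch.stack(xs)
    # Targets differ in length, so keep them as a list of dicts
    return X, list(targets)

class ToyDataset(Dataset):
    """Hypothetical stand-in for llni_Dataset: 8 features, variable-length targets."""
    def __init__(self, lengths, num_features=8):
        self.lengths = lengths
        self.num_features = num_features

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        n = self.lengths[idx]
        X = torch.randn(self.num_features)
        target = {'locs': torch.randn(n), 'labels': torch.randint(0, 2, (n,))}
        return X, target

dl = DataLoader(ToyDataset([5, 5, 9, 3]), batch_size=4, collate_fn=collate_fn)
X, targets = next(iter(dl))
head = torch.nn.Linear(8, 13)  # stand-in for the first layer of Net
out = head(X)
print(X.shape, out.shape)  # torch.Size([4, 8]) torch.Size([4, 13])
```

Alternatively, keeping the original collate function, `net(torch.stack(b[0]))` passes the network a 2D tensor.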

Does anybody have any insight into what I’m doing wrong and/or how I should approach this problem? I’d really appreciate any guidance or advice. Thanks a lot in advance.

If the model’s output is already padded to the maximum length of the targets, could you simply pad the targets to this length from the start? If they are all the same length, the standard collate function should just work.
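Padding inside the Dataset could look like the following minimal sketch, assuming a maximum of 13 targets to match the output heads. `pad_target` and `MAX_TARGETS` are hypothetical names; the returned `mask` records which positions are real so the padding can be ignored later. Once every target has the same length, the DataLoader’s default collate function stacks each dictionary entry on its own:

```python
import torch
from torch.utils.data import DataLoader

MAX_TARGETS = 13  # assumed maximum target length (matches the 13-unit output heads)

def pad_target(locs, labels, max_len=MAX_TARGETS):
    """Pad 1D locs/labels to max_len and return a mask of the real entries."""
    n = locs.shape[0]
    if n > max_len:
        raise ValueError(f"target length {n} exceeds max_len={max_len}")
    padded_locs = torch.zeros(max_len, dtype=locs.dtype)
    padded_locs[:n] = locs
    padded_labels = torch.zeros(max_len, dtype=labels.dtype)
    padded_labels[:n] = labels
    mask = torch.arange(max_len) < n  # True where the target is real
    return {'locs': padded_locs, 'labels': padded_labels, 'mask': mask}

# Toy samples standing in for llni_Dataset items (8 features, 5/9/3 targets)
samples = [
    (torch.randn(8), pad_target(torch.randn(n), torch.randint(0, 2, (n,))))
    for n in (5, 9, 3)
]

# With fixed-length targets, the default collate function stacks every dict entry
X, target = next(iter(DataLoader(samples, batch_size=3)))
print(X.shape, target['locs'].shape, target['mask'].sum(dim=1))
```

In `llni_Dataset.__getitem__`, this would amount to returning `X, pad_target(locs, labels)` instead of the unpadded dictionary.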

Yes, I definitely could. However, A) I was hoping to make it work with variable-length targets as they are, and B) I’m not sure what the best padding scheme would be.

I’m not sure the padding scheme would be all that important unless there are things like time dependencies in the data/labels or other constraints that you wish to enforce.
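One reason the padding value matters little is that padded positions can be masked out of the loss entirely. A minimal sketch, assuming padded targets with a boolean mask and assuming the `labels` head of `Net` is treated as raw logits for per-position binary classification (`masked_multitask_loss` is a hypothetical name). It checks that two different padding values give the same loss:

```python
import torch
import torch.nn.functional as F

def masked_multitask_loss(pred_locs, pred_logits, target_locs, target_labels, mask):
    """All inputs are (batch, max_len); mask is True where a target is real."""
    w = mask.float()
    n_valid = w.sum().clamp(min=1.0)  # avoid division by zero on an empty batch
    # Per-element losses, with padded positions zeroed out before averaging
    reg = (F.mse_loss(pred_locs, target_locs, reduction='none') * w).sum() / n_valid
    cls = (F.binary_cross_entropy_with_logits(
        pred_logits, target_labels.float(), reduction='none') * w).sum() / n_valid
    return reg + cls

torch.manual_seed(0)
pred_locs, pred_logits = torch.randn(2, 13), torch.randn(2, 13)
mask = torch.arange(13).unsqueeze(0) < torch.tensor([[5], [9]])
locs, labels = torch.randn(2, 13), torch.randint(0, 2, (2, 13))

# Same real targets, different padding values: the loss does not change
loss_a = masked_multitask_loss(pred_locs, pred_logits, locs, labels, mask)
loss_b = masked_multitask_loss(
    pred_locs, pred_logits,
    torch.where(mask, locs, torch.full_like(locs, 999.0)),
    torch.where(mask, labels, torch.ones_like(labels)),
    mask)
print(loss_a.item(), loss_b.item())
```

Averaging over valid elements weights examples with more targets more heavily; averaging per example first is an equally reasonable choice.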

Example padding function:
torch.nn.utils.rnn.pad_sequence (PyTorch 1.9.0 documentation)
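`pad_sequence` could also be used inside a collate function instead of the Dataset. A minimal sketch (`pad_collate` is a hypothetical name) that stacks the inputs, pads each target field, and builds a mask from the original lengths:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    xs, targets = zip(*batch)
    X = torch.stack(xs)  # (batch, num_features)
    # pad_sequence pads to the longest sequence in *this* batch, not to 13
    locs = pad_sequence([t['locs'] for t in targets], batch_first=True, padding_value=0.0)
    labels = pad_sequence([t['labels'] for t in targets], batch_first=True, padding_value=0)
    lengths = torch.tensor([t['locs'].shape[0] for t in targets])
    # mask[i, j] is True when position j holds a real target for example i
    mask = torch.arange(locs.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)
    return X, {'locs': locs, 'labels': labels, 'mask': mask}

# Toy batch standing in for llni_Dataset items (8 features, 5/9/3 targets)
batch = [(torch.randn(8), {'locs': torch.randn(n), 'labels': torch.randint(0, 2, (n,))})
         for n in (5, 9, 3)]
X, target = pad_collate(batch)
print(X.shape, target['locs'].shape)  # torch.Size([3, 8]) torch.Size([3, 9])
```

Because the padded width varies by batch, the model’s 13-wide outputs would need slicing, e.g. `locs_pred[:, :target['locs'].shape[1]]`, before a masked loss is applied.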

Awesome, thank you so much! I will definitely proceed with this strategy. However, for my own edification: I know it is possible to train on variable-length targets, and I would really like to learn how to do that correctly. If anybody has any insight into how to accomplish that, I would be very grateful for the help.
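One way to train on variable-length targets without padding them is to keep the targets as a list (as the original collate function does, once the inputs are stacked) and compute each example’s loss on only the first n outputs. A minimal sketch, assuming model outputs of shape (batch, 13) and treating the label head as logits; `variable_length_loss` is a hypothetical name:

```python
import torch
import torch.nn.functional as F

def variable_length_loss(pred_locs, pred_logits, targets):
    """pred_*: (batch, max_len) outputs; targets: list of dicts of 1D tensors."""
    losses = []
    for i, t in enumerate(targets):
        n = t['locs'].shape[0]
        # Only the first n predictions are compared with this example's targets
        reg = F.mse_loss(pred_locs[i, :n], t['locs'])
        cls = F.binary_cross_entropy_with_logits(pred_logits[i, :n], t['labels'].float())
        losses.append(reg + cls)
    # Each example contributes equally, regardless of its number of targets
    return torch.stack(losses).mean()

torch.manual_seed(0)
pred_locs = torch.zeros(2, 13, requires_grad=True)
pred_logits = torch.zeros(2, 13, requires_grad=True)
targets = [{'locs': torch.randn(n), 'labels': torch.randint(0, 2, (n,))} for n in (5, 9)]
loss = variable_length_loss(pred_locs, pred_logits, targets)
loss.backward()  # outputs beyond each target's length receive zero gradient
```

This is numerically similar to padding plus masking, just expressed as a Python loop. In either case the network still emits 13 values at inference time, so some separate way of knowing how many are real (for example, an extra output predicting the count) would still be needed.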