I’m having trouble setting up a multi-target learning problem where the targets have variable length. I think the problem is in how I’ve set up my data, but it could also be in the model.
My targets have a binary classification label and a regression value. Different examples may have targets of different lengths. Below is my Dataset code, which takes a pandas DataFrame and returns a 1D tensor for each training example and a dictionary with 1D tensors for the regression target (called `locs`) and the class (called `labels`). I’ve also printed out an example:
import torch
from torch.utils.data import Dataset, DataLoader

dep_vars = ['target_locs', 'target_labels']
ind_vars = [col for col in train_df.columns if col not in dep_vars]

class llni_Dataset(Dataset):
    def __init__(self, df):
        self.X = df[ind_vars]
        self.y = df[dep_vars]

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # convert X (one row of features) to a 1D float tensor
        X = torch.as_tensor(self.X.loc[idx].values, dtype=torch.float32)
        # set up the target: variable-length 1D tensors for locs and labels
        target = {}
        target['locs'] = torch.as_tensor(self.y.loc[idx, 'target_locs'], dtype=torch.float32)
        target['labels'] = torch.as_tensor([1 if 'road' in targ else 0 for targ in self.y.loc[idx, 'target_labels']], dtype=torch.int64)
        return X, target

train_ds = llni_Dataset(train_df)
print(train_ds[0])
which returns
(tensor([-1.2727e+01, 1.0000e+01, 1.2515e-02, 1.7460e-02, 3.4232e-02,
4.4352e-02, 4.9399e-02, 6.2449e-02, 7.0045e-02, 0.0000e+00,
3.6364e+00, 2.0909e+01, 2.4545e+01, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 2.5490e-04, 7.6469e-04,
1.1470e-03, 7.0096e-03, 1.1725e-02, 1.7460e-02, 1.4274e-02,
7.5194e-03, 1.7843e-03, 1.7843e-03, 5.4802e-03, 2.3450e-02,
5.6714e-02, 7.6978e-02, 5.5057e-02, 3.5558e-02, 4.5881e-02,
7.0096e-02, 6.2449e-02, 3.2244e-02, 9.1762e-03, 2.1666e-03,
6.3724e-04, 1.2745e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 2.5490e-04, 6.3724e-04,
4.8430e-03, 1.4402e-02, 4.4224e-02, 6.7930e-02, 6.8567e-02,
5.2763e-02, 4.7156e-02, 6.7930e-02, 6.9841e-02, 4.4352e-02,
1.8607e-02, 1.2108e-02, 1.4529e-02, 1.4147e-02, 9.9409e-03,
5.2254e-03, 2.1666e-03, 5.0979e-04, 0.0000e+00, 0.0000e+00,
1.2745e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00]),
{'locs': tensor([ 6.3051, -4.5587, 4.8963, 1.4027, -2.2576]),
'labels': tensor([1, 1, 0, 0, 0])})
This next part is probably where the error is. The only approach I have seen for batching targets of different lengths is a custom collate function like the one below; please let me know if there’s a better way to accomplish this. Using the collate function, I put the Dataset into a DataLoader and print out an example batch `b` (which we’ll use again below):
def collate_fn(batch):
    return tuple(zip(*batch))

train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)

for b in train_dl:
    print(b)
    break
which returns:
((tensor([-7.2727e+00, 9.0909e+00, 1.1855e-02, 2.9259e-02, 4.3675e-02,
5.1220e-02, 5.2019e-02, 5.3974e-02, 8.8631e-02, -9.0909e-01,
2.0000e+01, 2.2727e+01, -1.0000e+02, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 2.7538e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.8845e-04,
9.6383e-04, 5.0945e-03, 2.0929e-02, 5.1633e-02, 9.0600e-02,
1.1043e-01, 7.0910e-02, 5.2460e-02, 5.1909e-02, 4.8329e-02,
3.1531e-02, 1.1979e-02, 3.1669e-03, 6.8845e-04, 6.8845e-04,
2.7538e-04, 0.0000e+00, 1.3769e-04, 1.3769e-04, 6.8845e-04,
4.1307e-04, 1.3769e-03, 4.4061e-03, 1.0877e-02, 1.9965e-02,
3.0292e-02, 4.2133e-02, 5.3561e-02, 5.3974e-02, 5.1220e-02,
5.3974e-02, 6.5815e-02, 6.2786e-02, 4.4061e-02, 2.8226e-02,
1.1841e-02, 6.4714e-03, 3.4422e-03, 6.8845e-04, 2.7538e-04,
1.3769e-04, 1.3769e-04, 1.3769e-04, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 1.3769e-04, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 1.3769e-04, 0.0000e+00,
0.0000e+00]), tensor([-1.0000e+01, 9.0909e+00, 1.3721e-02, 4.6678e-02, 8.1144e-02,
9.3003e-02, 1.0189e-01, 1.1823e-01, 1.7623e-01, -1.8182e+00,
1.8182e+00, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 1.4145e-03, 2.3575e-03, 4.9507e-03, 5.8937e-03,
1.6267e-02, 3.3948e-02, 7.3789e-02, 1.3131e-01, 1.9426e-01,
1.6149e-01, 1.0727e-01, 9.3828e-02, 1.1387e-01, 9.2177e-02,
5.0922e-02, 1.0609e-02, 3.5362e-03, 1.1787e-03, 4.7150e-04,
4.7150e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00]), tensor([-1.1818e+01, 6.3636e+00, 2.2130e-02, 4.6158e-02, 7.5597e-02,
8.8660e-02, 1.0229e-01, 1.3275e-01, 1.6093e-01, -2.7273e+00,
-1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 2.3580e-04, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.7160e-04,
9.4319e-04, 1.8864e-03, 7.0740e-03, 2.1222e-02, 4.0322e-02,
8.1822e-02, 1.3558e-01, 1.5775e-01, 1.6482e-01, 1.3181e-01,
1.0682e-01, 9.5498e-02, 7.1447e-02, 4.8103e-02, 2.2872e-02,
7.7814e-03, 2.3580e-03, 9.4319e-04, 0.0000e+00, 2.3580e-04,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00]), tensor([-1.0000e+01, 2.1818e+01, 1.4368e-02, 3.1920e-02, 4.4716e-02,
5.0020e-02, 5.5548e-02, 6.2078e-02, 8.0558e-02, -1.8182e+00,
1.8182e+00, 1.0909e+01, -1.0000e+02, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
-1.0000e+02, -1.0000e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 1.1190e-04, 2.2380e-04, 6.7141e-04, 2.0142e-03,
7.6094e-03, 1.7345e-02, 3.8159e-02, 4.7670e-02, 5.4273e-02,
4.6663e-02, 6.3784e-02, 7.8220e-02, 8.3815e-02, 5.6958e-02,
3.1668e-02, 1.3316e-02, 4.4761e-03, 4.0285e-03, 6.6022e-03,
1.4212e-02, 2.9878e-02, 5.6734e-02, 8.0682e-02, 7.7660e-02,
6.9268e-02, 5.6399e-02, 5.2370e-02, 4.3418e-02, 3.2675e-02,
1.7345e-02, 7.9451e-03, 2.3499e-03, 6.7141e-04, 6.7141e-04,
0.0000e+00, 1.1190e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00])), ({'locs': tensor([ 4.9503, -6.5299, 3.3085, -0.1613, -3.5986]), 'labels': tensor([1, 1, 0, 0, 0])}, {'locs': tensor([ 5.3896, -6.0695, 3.9064, 0.3867, -3.2428]), 'labels': tensor([1, 1, 0, 0, 0])}, {'locs': tensor([ 3.1086, -9.3537, 0.5882, -2.7026, -6.3033]), 'labels': tensor([1, 1, 0, 0, 0])}, {'locs': tensor([ 5.0380, -7.7921, 2.1711, -1.4178, -4.9330, 7.1524, 10.0885, 13.6362,
17.3264]), 'labels': tensor([1, 1, 0, 0, 0, 1, 0, 0, 0])}))
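For comparison, one possible alternative to the `zip`-based collate function is to stack the fixed-length inputs into a single 2D tensor and pad the variable-length targets to the longest one in the batch. The following is only a minimal sketch under assumptions: the name `pad_collate_fn`, the padding values (0.0 for `locs`, -1 for `labels`), and the extra `lengths` entry are all hypothetical choices, and the toy batch stands in for real Dataset items.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate_fn(batch):
    """Hypothetical collate function: stack inputs, pad variable-length targets."""
    xs, targets = zip(*batch)
    X = torch.stack(xs)  # (batch, n_features); inputs all have the same length
    # Remember the true number of targets per example so padding can be ignored later.
    lengths = torch.tensor([len(t['locs']) for t in targets])
    # Padding values are assumptions: 0.0 for locs, -1 for labels.
    locs = pad_sequence([t['locs'] for t in targets], batch_first=True, padding_value=0.0)
    labels = pad_sequence([t['labels'] for t in targets], batch_first=True, padding_value=-1)
    return X, {'locs': locs, 'labels': labels, 'lengths': lengths}


# Toy batch standing in for two Dataset items with 2 and 3 targets.
toy_batch = [
    (torch.zeros(4), {'locs': torch.tensor([1.0, 2.0]), 'labels': torch.tensor([1, 0])}),
    (torch.ones(4), {'locs': torch.tensor([3.0, 4.0, 5.0]), 'labels': torch.tensor([1, 1, 0])}),
]
X, target = pad_collate_fn(toy_batch)
print(X.shape, target['locs'].shape, target['labels'].shape)
```

With a function like this passed as `collate_fn`, each batch would be a single input tensor plus a dictionary of padded target tensors, rather than a tuple of per-example items.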
I then tried to put together (what I was hoping would be) a very simple fully connected network. NOTE: I gave layers `lin4_1` and `lin4_2` an output size of 13 because 13 is the maximum length I expect the targets to take; again, I’m not sure if this is the correct approach.
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, input_size):
        super(Net, self).__init__()
        self.lin1 = nn.Linear(input_size, 200)
        self.lin2 = nn.Linear(200, 100)
        self.lin3 = nn.Linear(100, 50)
        self.lin4_1 = nn.Linear(50, 13)
        self.lin4_2 = nn.Linear(50, 13)
        self.bn1 = nn.BatchNorm1d(200)
        self.bn2 = nn.BatchNorm1d(100)
        self.bn3 = nn.BatchNorm1d(50)
        self.drops = nn.Dropout(0.2)

    def forward(self, x):
        x = self.bn1(F.relu(self.lin1(x)))
        x = self.drops(x)
        x = self.bn2(F.relu(self.lin2(x)))
        x = self.drops(x)
        x = self.bn3(F.relu(self.lin3(x)))
        x = self.drops(x)
        labels = self.lin4_1(x)
        locs = self.lin4_2(x)
        return locs, labels

net = Net(len(ind_vars))
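If the two heads keep a fixed width of 13, one way such outputs could be compared against shorter targets is to pad the targets to 13 slots and mask the padded slots out of the loss. The sketch below is an illustration under assumptions, not a confirmed solution: `masked_multitask_loss` is a hypothetical helper, it assumes the targets are already padded to 13 and the true lengths are known, it treats each `labels` output as a single binary logit, and the equal weighting of the two losses is arbitrary.

```python
import math

import torch
import torch.nn.functional as F

MAX_TARGETS = 13  # fixed head width, matching lin4_1 / lin4_2


def masked_multitask_loss(pred_locs, pred_logits, locs, labels, lengths):
    """Hypothetical helper: MSE on locs plus BCE on labels, over real slots only.

    Predictions and targets are (batch, MAX_TARGETS); lengths is (batch,).
    """
    positions = torch.arange(pred_locs.shape[1], device=pred_locs.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)  # True on real slots
    reg_loss = F.mse_loss(pred_locs[mask], locs[mask])
    cls_loss = F.binary_cross_entropy_with_logits(
        pred_logits[mask], labels[mask].float()
    )
    return reg_loss + cls_loss  # equal weighting is an assumption


# Toy check: two examples with 5 and 9 real targets, padded out to 13 slots.
lengths = torch.tensor([5, 9])
locs = torch.zeros(2, MAX_TARGETS)
labels = torch.zeros(2, MAX_TARGETS, dtype=torch.int64)
locs[0, :5] = 1.0
locs[1, :9] = 1.0
labels[0, :2] = 1
labels[1, :2] = 1

pred_locs = torch.zeros(2, MAX_TARGETS)    # stand-in for the locs head output
pred_logits = torch.zeros(2, MAX_TARGETS)  # stand-in for the labels head output
loss = masked_multitask_loss(pred_locs, pred_logits, locs, labels, lengths)
print(loss.item())
```

Because of the mask, whatever values sit in the padded slots have no effect on the loss or its gradients.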
However, when I try to run a batch of training examples through the network, I get an error saying that the batch is a tuple and has no attribute `dim`, which I take to be a problem with the way I collate the batch. Running
net(b[0])
gives:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/tmp/ipykernel_6716/498270831.py in <module>
----> 1 net(b[0])
~/miniconda3/envs/emp_path/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/tmp/ipykernel_6716/4134123215.py in forward(self, x)
15
16 def forward(self, x):
---> 17 x = self.bn1(F.relu(self.lin1(x)))
18 x = self.drops(x)
19 x = self.bn2(F.relu(self.lin2(x)))
~/miniconda3/envs/emp_path/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/miniconda3/envs/emp_path/lib/python3.9/site-packages/torch/nn/modules/linear.py in forward(self, input)
91
92 def forward(self, input: Tensor) -> Tensor:
---> 93 return F.linear(input, self.weight, self.bias)
94
95 def extra_repr(self) -> str:
~/miniconda3/envs/emp_path/lib/python3.9/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1686 if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
1687 return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
-> 1688 if input.dim() == 2 and bias is not None:
1689 # fused op is marginally faster
1690 ret = torch.addmm(bias, input, weight.t())
AttributeError: 'tuple' object has no attribute 'dim'
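The traceback is consistent with `b[0]` being a tuple of four 1D tensors (which is what `tuple(zip(*batch))` produces) rather than a single 2D tensor, since `nn.Linear` expects a tensor input. A minimal sketch of that reading, using small stand-ins for `b[0]` and for the network (the sizes 8 and 3 are arbitrary, not taken from the real data):

```python
import torch
import torch.nn as nn

# Stand-in for b[0]: a tuple of four 1D feature tensors, as produced by zip.
inputs = tuple(torch.randn(8) for _ in range(4))

layer = nn.Linear(8, 3)  # stand-in for the first layer of the network

# Stacking the tuple gives one (batch, n_features) tensor the layer can consume.
batch_X = torch.stack(inputs)
out = layer(batch_X)
print(batch_X.shape, out.shape)
```

Under this reading, something like `net(torch.stack(b[0]))` would get past the `AttributeError`, although the variable-length targets would still need separate handling.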
Does anybody have any insight into what I’m doing wrong and/or how I should be approaching this problem? I’d really appreciate any guidance or advice anybody might have. Thanks a lot in advance.