Loading multi-output text data


(James Callinicos) #1

Hey Team

Longtime listener on this forum, but first post!

I’m currently having issues using torchtext dataloaders to load multi-output text data from CSVs. The error occurs during training, when iterating through the Iterators produced, but I suspect the actual issue originates earlier in the code, so I’ll do my best to give a complete view.

An example CSV may look like this:

+---------------------------------------------------------------------------------+----------+----------+
|                                                                            text | labelOne | labelTwo |
+---------------------------------------------------------------------------------+----------+----------+
|                                                  rugby is a good sport to watch | sport    | rugby    |
+---------------------------------------------------------------------------------+----------+----------+
| Uber files papers to go public, setting up the year's most anticipated tech IPO | business | IPO      |
+---------------------------------------------------------------------------------+----------+----------+

Where the model is trained to predict labelOne and labelTwo.

I load the data using the following piece of code:

TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True)
L1_LABEL = data.Field(sequential=False, use_vocab=False)
L2_LABEL = data.Field(sequential=False, use_vocab=False)

tv_datafields = [("id", None),
                 ("text", TEXT),
                 ("labelOne", L1_LABEL),
                 ("labelTwo", L2_LABEL)]

test_datafields = [("id", None),
                   ("text", TEXT),
                   ("labelOne", None),
                   ("labelTwo", None)]

train = TabularDataset(path='train.csv',
                      format='csv',
                      fields=tv_datafields,
                      skip_header=True)

test = TabularDataset(path='test.csv',
                     format='csv',
                     fields=test_datafields,
                     skip_header=True)

I then create an Iterator with the following:

train_iter = BucketIterator(train,
                            batch_size=12,
                            sort_key=lambda x: len(x.text),
                            sort_within_batch=False,
                            repeat=False
                           )

test_iter = Iterator(test, batch_size=12, sort_within_batch=False, repeat=False)

When training the model, I get the following error message:

ValueError                                Traceback (most recent call last)
<ipython-input-179-a309dedac888> in <module>
      7     model.train()
      8 
----> 9     for x, y1, y2 in train_iter:
     10         opt.zero_grad()
     11 

ValueError: not enough values to unpack (expected 3, got 2)

The complete training code is:

for epoch in range(n_epochs):
    
    running_loss = 0.0
    running_l1_corrects = 0
    running_l2_corrects = 0
    
    model.train()
    
    for x, y1, y2 in train_iter:
        opt.zero_grad()

        l1_pred, l2_pred = model(x)

        loss1 = loss_func(l1_pred, y1)
        loss2 = loss_func(l2_pred, y2)
        loss = loss1 + loss2
        loss.backward()
        opt.step()

        running_loss += loss.item() * x.size(0)
    epoch_loss = running_loss / len(train)

Any ideas or solutions would be highly appreciated! Alternatively, any new methods of loading data outside of torchtext would be perfect as well.
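For context, the kind of torchtext-free alternative I’d be happy with is a plain map-style dataset over the CSV, something like the sketch below (all names are placeholders, and the sample file just stands in for my real train.csv):

```python
import csv

# Hypothetical sample CSV standing in for the real train.csv, just so the
# sketch runs end to end.
with open("train_sample.csv", "w", newline="") as f:
    f.write("text,labelOne,labelTwo\n"
            "rugby is a good sport to watch,sport,rugby\n")

class MultiOutputTextDataset:
    """Minimal map-style dataset: defines __len__ and __getitem__, which is
    all torch.utils.data.DataLoader needs from a dataset object."""

    def __init__(self, path):
        with open(path, newline="") as f:
            self.rows = list(csv.DictReader(f))

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        # Real tokenisation / label encoding would go here; returning raw
        # strings keeps the sketch self-contained.
        return row["text"], row["labelOne"], row["labelTwo"]

ds = MultiOutputTextDataset("train_sample.csv")
text, y1, y2 = ds[0]  # unpacks cleanly into the three fields
```

Each item unpacks into exactly three values, so the `for x, y1, y2 in ...` loop above would at least be well-defined.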

Thanks heaps in advance.


(Chris) #2

Maybe your CSV is a bit off; for example, one of the fields in a line is empty, or a separator is missing. Does it work if you use a copy of your CSV file containing only the first 10 lines (where you can easily check that all lines look correct)? At least you can then tell whether it’s a problem with the file rather than with torchtext.
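Making that truncated copy is only a few lines of plain Python; a sketch (the sample file written at the top is just a stand-in for your real train.csv):

```python
# Stand-in file so the sketch runs on its own; use your real train.csv.
with open("train.csv", "w") as f:
    f.write("text,labelOne,labelTwo\n")
    for i in range(100):
        f.write(f"sample row {i},sport,rugby\n")

# Keep the header plus the first 10 data rows in a small copy to test with.
with open("train.csv") as src, open("train_small.csv", "w") as dst:
    for i, line in enumerate(src):
        if i > 10:  # line 0 is the header
            break
        dst.write(line)
```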

For testing the CSV file, you could also try pandas. When loading a file, pandas often gives more useful errors when something with the file is off. If pandas does indeed throw an error, you can also use it to clean your CSV:

  • Load your CSV file into a data frame with error_bad_lines=False to skip erroneous lines
  • Save the data frame to a new CSV file, which you can then feed to torchtext
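A sketch of those two steps, assuming a reasonably recent pandas (where `error_bad_lines=False` has been replaced by `on_bad_lines="skip"`); the in-memory sample here just stands in for a real file on disk:

```python
import io
import pandas as pd

# Sample standing in for the real CSV: one good row, one row with too many
# fields, one row with a missing field.
raw = io.StringIO(
    "text,labelOne,labelTwo\n"
    "rugby is a good sport to watch,sport,rugby\n"
    "this row has,too,many,fields\n"
    "this row is missing a label,business\n"
)

# on_bad_lines="skip" (pandas >= 1.3; error_bad_lines=False before that)
# silently drops rows that have too many separators.
df = pd.read_csv(raw, on_bad_lines="skip")

# Rows that are short a field parse as NaN instead of erroring; drop those too.
df = df.dropna()

# Write the cleaned frame to a new CSV that torchtext can read.
df.to_csv("train_clean.csv", index=False)
```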