Hi. I am training a CNN model using torch text and getting a strange error when trying to perform evaluation on my dev set.
I have my data loaded into two CSV’s and have their location stored in pathD and test_path.
def twitter_data(text_field, label_field, **kargs):
train_data = data.TabularDataset(
path=pathD, format='csv',
fields=[('twitter_text', text_field),
('label', label_field)])
dev_data = data.TabularDataset(
path=test_path, format='csv',
fields=[('twitter_text', text_field),
('label', label_field)])
text_field.build_vocab(train_data, dev_data, vectors=GloVe(name='twitter.27B', dim=200))
label_field.build_vocab(train_data, dev_data)
train_iter, dev_iter = data.Iterator.splits(
(train_data, dev_data),
batch_sizes=(args.batch_size, len(dev_data)),
**kargs)
return train_iter, dev_iter
text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)
train_iter, dev_iter = twitter_data(text_field, label_field, device=-1, repeat=False)
This code works fine and returns an iter for both the train and dev set. When I train, I run the following code and all is dandy.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay = args.weight_decay) scheduler = StepLR(optimizer, step_size = 5, gamma = .25) model.train() for epoch in range(1, args.epochs+1): scheduler.step() for batch in train_iter: ....
Though when I try to call my eval() function, I get an error on the last line here.
def eval(data_iter, model, args):
model.eval()
corrects, avg_loss = 0, 0
for batch in data_iter: # error occurs hereThe stack trace appears as follows.
for batch in data_iter:
File “/usr/local/lib/python3.5/dist-packages/torchtext/data/iterator.py”, line 164, in iter
self.init_epoch()
File “/usr/local/lib/python3.5/dist-packages/torchtext/data/iterator.py”, line 140, in init_epoch
self.create_batches()
File “/usr/local/lib/python3.5/dist-packages/torchtext/data/iterator.py”, line 151, in create_batches
self.batches = batch(self.data(), self.batch_size, self.batch_size_fn)
File “/usr/local/lib/python3.5/dist-packages/torchtext/data/iterator.py”, line 125, in data
xs = sorted(self.dataset, key=self.sort_key)
TypeError: unorderable types: Example() < Example()
The confusing thing is the error persists even when I cheat and use the same csv file for train and dev splits, so there isn’t an issue with the dev set file.