Using ignite with torchtext

Is it possible to use ignite with torchtext similar to the MNIST example here? In this example, once you have constructed your DataLoader and model, the code is essentially

model = Net()
opt = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

trainer = create_supervised_trainer(model, opt, criterion)
trainer.run(data_loader, max_epochs=10)

with some additional code for logging if desired.

If you follow an example from torchtext with a similar “input-output” structure such as this one

TEXT = data.Field()
LABELS = data.Field()

train, val, test = data.TabularDataset.splits(
    path='/data/pos_wsj/pos_wsj', train='_train.tsv',
    validation='_dev.tsv', test='_test.tsv', format='tsv',
    fields=[('text', TEXT), ('labels', LABELS)])

train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (train, val, test), batch_sizes=(16, 256, 256),
    sort_key=lambda x: len(x.text), device=0)

TEXT.build_vocab(train)
LABELS.build_vocab(train)

In this case, as far as I can tell, you can’t pass train_iter to the trainer as written above, because each batch from train_iter is not a tuple as it is in the MNIST case. Is there a way to accommodate this difference with ignite?

It seems to me that the only option is to make use of the custom _update and _inference functions as outlined here, but that requires about as much code as not using ignite at all.

@jacobcvt12 I’m not very familiar with NLP tasks and torchtext, but I’ll try to explain how we would like ignite to be used. As this library is under active development (currently version 0.1.0), we can take a look at how to adapt ignite to cover NLP problems too. As always, PRs are welcome.

First, ignite offers a modular training loop and the flexibility to interact with it at different steps during training. In addition, various handlers (ModelCheckpoint, EarlyStopping, metrics computation) are provided out of the box to simplify user code.

Concerning the trainer (an instance of Engine), in the general case we need to write an _update function for the particular task. Helper functions such as create_supervised_trainer are handy for vanilla image classification and similar tasks where a batch is (batch_x, batch_y), there is a single loss function, etc.
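To make the idea of a process function concrete, here is a toy, stdlib-only sketch of the pattern. `MiniEngine` is hypothetical and much simpler than ignite's real `Engine` (no exception handling, termination logic or custom events); it only shows how a user-supplied `update(engine, batch)` function is called once per batch while handlers fire at iteration and epoch boundaries.

```python
from collections import defaultdict
from types import SimpleNamespace


class MiniEngine:
    """Toy stand-in for an ignite-style Engine (hypothetical, illustration only)."""

    ITERATION_COMPLETED = "iteration_completed"
    EPOCH_COMPLETED = "epoch_completed"

    def __init__(self, process_function):
        # process_function(engine, batch) -> output, like `_update` in ignite
        self._process_function = process_function
        self._handlers = defaultdict(list)
        self.state = None

    def on(self, event_name):
        # Decorator that registers a handler for the given event
        def decorator(handler):
            self._handlers[event_name].append(handler)
            return handler
        return decorator

    def _fire(self, event_name):
        for handler in self._handlers[event_name]:
            handler(self)

    def run(self, data, max_epochs=1):
        self.state = SimpleNamespace(epoch=0, iteration=0, output=None)
        for _ in range(max_epochs):
            self.state.epoch += 1
            for batch in data:
                self.state.iteration += 1  # keeps counting across epochs
                self.state.output = self._process_function(self, batch)
                self._fire(self.ITERATION_COMPLETED)
            self._fire(self.EPOCH_COMPLETED)
        return self.state


# Usage: the "training step" here just sums each (x, y) pair.
def update(engine, batch):
    x, y = batch
    return x + y


engine = MiniEngine(update)
epoch_outputs = []


@engine.on(MiniEngine.EPOCH_COMPLETED)
def record_last_output(engine):
    epoch_outputs.append((engine.state.epoch, engine.state.output))


state = engine.run([(1, 2), (3, 4)], max_epochs=2)
print(epoch_outputs)    # [(1, 7), (2, 7)]
print(state.iteration)  # 4
```

Because the loop only ever calls `update(engine, batch)`, the batch can be anything the update function knows how to unpack, which is why a custom `_update` is the general-purpose option.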

Later I’ll take a look at this example and may provide a port to ignite, to better estimate the usefulness of ignite versus the same code without ignite :slight_smile:

HTH

Thanks for the reply @vfdev-5. I put together some code to show what a batch looks like in torchtext.

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchtext.data import Field, BucketIterator
from torchtext.datasets import IMDB

# load IMDB reviews and ratings data 
# (built into torchtext similar to MNIST with torchvision)

## define fields in data
TEXT = Field(lower=True)
LABEL = Field(sequential=False)

## import datasets
train, test = IMDB.splits(TEXT, LABEL, root="../data")
TEXT.build_vocab(train, min_freq=20)
LABEL.build_vocab(train)
train_iter, test_iter = BucketIterator.splits((train, test),
                                              batch_size=100,
                                              repeat=False)

# show what a batch looks like
for batch in train_iter:
    break

data, target = batch.text.t(), batch.label - 1

At the very bottom of the code, I grab a batch from the iterator. As opposed to the batch in MNIST, this is not a tuple, and instead there are fields within the batch that need to be extracted.
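The mismatch can be illustrated without torchtext. In the sketch below, `SimpleNamespace` stands in for torchtext's batch object and nested lists stand in for tensors (both are assumptions for illustration). A torchtext `Field` is not `batch_first` by default, so `batch.text` is laid out as `(seq_len, batch_size)`, which is why the code above calls `.t()`. The `- 1` is most likely needed because the label vocabulary built with `sequential=False` reserves index 0 for `<unk>`, so the two real labels land on 1 and 2. The hypothetical `prepare_batch` helper turns the attribute-style batch into the `(x, y)` tuple that MNIST-style code expects.

```python
from types import SimpleNamespace


def transpose(rows):
    """Swap the two outer dimensions of a nested list (like tensor.t())."""
    return [list(col) for col in zip(*rows)]


def prepare_batch(batch):
    """Hypothetical adapter: attribute-style batch -> (x, y) tuple.

    batch.text  : token ids laid out as (seq_len, batch_size)
    batch.label : label ids where index 0 is assumed to be <unk>
    """
    x = transpose(batch.text)                  # -> (batch_size, seq_len)
    y = [label - 1 for label in batch.label]   # shift {1, 2} -> {0, 1}
    return x, y


# Stand-in batch: 3 time steps, 2 examples
batch = SimpleNamespace(
    text=[[5, 9],
          [6, 8],
          [7, 1]],
    label=[2, 1],
)

# MNIST-style tuple unpacking does not work on an attribute-style object
# like this stand-in.
try:
    x, y = batch
except TypeError as exc:
    print("unpacking failed:", type(exc).__name__)

x, y = prepare_batch(batch)
print(x)  # [[5, 6, 7], [9, 8, 1]]
print(y)  # [1, 0]
```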

I noticed in the “Concepts” documentation in the second code block there is the following line

x, y = prepare_batch(batch)

Is it possible to define prepare_batch? If so, that would really simplify this task.

Thanks, now I understand this example and the problem you are facing better. At the moment, we cannot use create_supervised_trainer with a custom prepare_batch function without a hack such as replacing the internal function with a custom one:

ignite.engine._prepare_batch = _foo_prep_batch

Could you please provide a basic model and loss function for this example, so we can see whether prepare_batch is the only thing you need to change?
In the general case we suggest defining a processing function (like _update in create_supervised_trainer) for the Engine; however, we have already been asked several times for a way to reuse create_supervised_trainer with a custom _prepare_batch. Let us see if we can find a flexible solution…
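A more flexible alternative to monkey-patching is to make the batch-preparation step a parameter of the factory, so each trainer receives its own function. The sketch below is plain Python and hypothetical: `create_supervised_update` and `default_prepare_batch` are illustrative names, and the real forward/backward/optimizer step is replaced by a stub. It only shows the shape of the design.

```python
def default_prepare_batch(batch):
    """MNIST-style default: the batch already is an (x, y) tuple."""
    x, y = batch
    return x, y


def create_supervised_update(step, prepare_batch=default_prepare_batch):
    """Return an update(engine, batch) function.

    step(x, y) stands in for forward pass + loss + backward + optimizer step.
    prepare_batch(batch) -> (x, y) is injected instead of patched globally.
    """
    def _update(engine, batch):
        x, y = prepare_batch(batch)
        return step(x, y)
    return _update


def fake_step(x, y):
    # Stub "loss": absolute difference between input and target
    return abs(x - y)


# Tuple batches use the default; attribute-style batches get a custom adapter.
tuple_update = create_supervised_update(fake_step)
dict_update = create_supervised_update(
    fake_step, prepare_batch=lambda b: (b["text"], b["label"] - 1)
)

print(tuple_update(None, (5, 3)))                  # 2
print(dict_update(None, {"text": 5, "label": 2}))  # 4
```

Unlike the global patch, the two update functions coexist without interfering with each other.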

Thanks @vfdev-5. That makes sense.

I added a model, loss, etc. to the above code:

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchtext.data import Field, BucketIterator
from torchtext.datasets import IMDB

# load IMDB reviews and ratings data 
# (built into torchtext similar to MNIST with torchvision)

## define fields in data
TEXT = Field(lower=True)
LABEL = Field(sequential=False)

## import datasets
train, test = IMDB.splits(TEXT, LABEL, root="../data")
TEXT.build_vocab(train, min_freq=20)
LABEL.build_vocab(train)
train_iter, test_iter = BucketIterator.splits((train, test),
                                              batch_size=100,
                                              repeat=False)

nwords = len(TEXT.vocab)
ntags = 2


class BOW(nn.Module):
    """Simple bag of words model"""
    def __init__(self, nwords, ntags):
        super(BOW, self).__init__()
        self.bias = nn.Parameter(torch.randn(ntags))
        self.embeddings = nn.Embedding(nwords, ntags)

    def forward(self, input):
        output = self.embeddings(input).sum(dim=1)
        output = output + self.bias

        return F.log_softmax(output, dim=1)

model = BOW(nwords, ntags)
opt = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

for epoch in range(2):
    for batch_idx, batch in enumerate(train_iter):
        model.train()
        opt.zero_grad()
        x, y = batch.text.t(), batch.label - 1 # only difference from `_update`
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        opt.step()

        if batch_idx % 10 == 0:
            print("Epoch %d batch %d" % (epoch, batch_idx)) 

The only difference from the “typical” update step that I see is the transformation of the batch to input and output.

Okay, I see, thanks! Let us take a look at how we could handle such situations.

And I hope you can still appreciate the modularity and flexibility of ignite, despite this copy-paste issue for the trainer and evaluator.

And do not hesitate to open issues for feature requests on GitHub if you run into other problems. As always, PRs are welcome :slight_smile:


So, a complete example with ignite would be:

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchtext.data import Field, BucketIterator
from torchtext.datasets import IMDB

from ignite.engine import Events, Engine, create_supervised_evaluator
from ignite.metrics import CategoricalAccuracy, Loss


## define fields in data
TEXT = Field(lower=True)
LABEL = Field(sequential=False)

## import datasets
train, test = IMDB.splits(TEXT, LABEL, root="../data")
TEXT.build_vocab(train, min_freq=20)
LABEL.build_vocab(train)
train_loader, test_loader = BucketIterator.splits((train, test), 
                                                  batch_size=100,
                                                  repeat=False)

nwords = len(TEXT.vocab)
ntags = 2


class BOW(nn.Module):
    """Simple bag of words model"""
    def __init__(self, nwords, ntags):
        super(BOW, self).__init__()
        self.bias = nn.Parameter(torch.randn(ntags))
        self.embeddings = nn.Embedding(nwords, ntags)

    def forward(self, input):
        output = self.embeddings(input).sum(dim=1)
        output = output + self.bias
        return F.log_softmax(output, dim=1)


model = BOW(nwords, ntags).to("cuda")
opt = optim.SGD(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()


def train_update(engine, batch):
    model.train()
    opt.zero_grad()
    x, y = batch.text.t(), batch.label - 1
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    opt.step()
    return loss.item()    


trainer = Engine(train_update)


def evaluate_update(engine, batch):
    model.eval()
    with torch.no_grad():
        x, y = batch.text.t(), batch.label - 1
        y_pred = model(x)
        return y_pred, y    


evaluator = Engine(evaluate_update)

metrics={'accuracy': CategoricalAccuracy(), 'loss': Loss(criterion)}

for name, metric in metrics.items():
    metric.attach(evaluator, name)

log_interval = 10    

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    it = (engine.state.iteration - 1) % len(train_loader) + 1  # 1-based index within the epoch
    if it % log_interval == 0:
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
              "".format(engine.state.epoch, it, len(train_loader), engine.state.output))


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_nll = metrics['loss']
    print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(engine.state.epoch, avg_accuracy, avg_nll))


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(test_loader)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_nll = metrics['loss']
    print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(engine.state.epoch, avg_accuracy, avg_nll))
    

trainer.run(train_loader, max_epochs=5)
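In `log_training_loss`, `engine.state.iteration` keeps counting across epochs, so `(engine.state.iteration - 1) % len(train_loader) + 1` converts it back into a 1-based position within the current epoch. A quick check of that arithmetic, assuming a loader of 250 batches (IMDB's 25,000 training reviews at batch_size=100); the helper name `iteration_in_epoch` is only for illustration:

```python
def iteration_in_epoch(global_iteration, epoch_length):
    """Map a running iteration counter (starting at 1) to a 1-based
    index within the current epoch."""
    return (global_iteration - 1) % epoch_length + 1


epoch_length = 250  # e.g. 25,000 training reviews / batch_size 100

for it in (1, 250, 251, 500, 510):
    print(it, "->", iteration_in_epoch(it, epoch_length))
# 1 -> 1
# 250 -> 250
# 251 -> 1
# 500 -> 250
# 510 -> 10
```

The `- 1` / `+ 1` pair matters: a plain `iteration % epoch_length` would report 0 instead of 250 on the last batch of each epoch.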

Finally, we decided to add a prepare_batch argument to the helper functions create_supervised_*. For example, see here.

So the examples above can now use create_supervised_* with a custom prepare_batch function.


This looks great @vfdev-5! Thanks so much for your work on ignite!
