Ignite: correct way to use the library, and how to pass the model to the callable

Hi everyone,
I’m new to PyTorch and Ignite, and I’m trying to figure out the best way to use them.
I have read almost all of the Ignite tutorials, but I still don’t understand the correct way to use the library.

In every tutorial I see, the function used to define an engine uses the model internally, but the model is never passed to it as a parameter.
I would have thought that the correct way to use Ignite is to “hide” everything from the final user, so I would have defined everything that happens during training inside a fit method of some model class.

But since I cannot pass parameters to the functions given to the engine, this is more difficult. (I’m currently passing the model as an optional parameter, but it doesn’t feel like the correct way.)

Can anyone explain, in general terms, the correct way to use the Ignite library?

Thank you all!

@damicoedoardo The basic usage of Ignite is shown below. Most handlers rely on variables from the enclosing scope (here, the local scope of main) rather than receiving them as arguments:


from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import Accuracy


def main():
    # get_dataloader and prepare_batch are your own helpers
    train_loader, val_loader = get_dataloader(...)
    model = ...
    optimizer = ...  # usually built from model.parameters(), so define the model first
    criterion = ...

    def train_step(engine, batch):
        model.train()                # model is defined in the scope of main
        optimizer.zero_grad()        # optimizer is defined in the scope of main
        x, y = prepare_batch(batch)
        y_pred = model(x)
        loss = criterion(y_pred, y)  # criterion is defined in the scope of main
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(train_step)
    evaluator = create_supervised_evaluator(model, metrics={'acc': Accuracy()})

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run(val_loader)

    trainer.run(train_loader, max_epochs=10)
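Strictly speaking, these are not module-level globals: train_step is a closure over the local variables of main, and the engine keeps those variables alive through the function it holds. A minimal pure-Python sketch of the same shape, using a hypothetical DummyModel and a placeholder loss instead of PyTorch objects so that it runs on its own:

```python
class DummyModel:
    """Hypothetical stand-in for a torch.nn.Module."""

    def __init__(self):
        self.training = False

    def train(self):
        # mimics nn.Module.train(): switch to training mode
        self.training = True
        return self

    def __call__(self, x):
        return 2 * x


def main():
    model = DummyModel()  # local to main, not a global

    def criterion(y_pred, y):
        # placeholder loss: absolute error
        return abs(y_pred - y)

    def train_step(engine, batch):
        # model and criterion are resolved through the closure over main's locals
        model.train()
        x, y = batch
        return criterion(model(x), y)

    return train_step


step = main()
print(sorted(step.__code__.co_freevars))  # ['criterion', 'model']
print(step(None, (3, 5)))                 # 1
```

The names listed in co_freevars are exactly the variables the inner function captured from main; nothing is looked up in the module namespace.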

If you prefer not to rely on the enclosing scope like above, you can do one of the following:

  1. Define a custom trainer class and pass everything it needs to its constructor:

from ignite.engine import Engine


class SpecificModelTrainer:

    def __init__(self, model, optimizer, criterion, prepare_batch):
        # store everything the training step needs on the instance
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.prepare_batch = prepare_batch

    def train_step(self, engine, batch):
        self.model.train()                # model was passed to __init__
        self.optimizer.zero_grad()        # so was the optimizer
        x, y = self.prepare_batch(batch)
        y_pred = self.model(x)
        loss = self.criterion(y_pred, y)  # and the criterion
        loss.backward()
        self.optimizer.step()
        return loss.item()


def main():
    train_loader, val_loader = get_dataloader(...)
    model = ...
    optimizer = ...
    criterion = ...

    model_trainer = SpecificModelTrainer(model, optimizer, criterion, prepare_batch)

    # a bound method already carries self, so it has the (engine, batch) signature
    trainer_engine = Engine(model_trainer.train_step)

    ...
  2. Store the objects in trainer.state:
from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import Accuracy


def main():
    train_loader, val_loader = get_dataloader(...)
    model = ...
    optimizer = ...
    criterion = ...

    def train_step(engine, batch):
        state = engine.state
        state.model.train()                # stored on the state by setup_state
        state.optimizer.zero_grad()
        x, y = state.prepare_batch(batch)
        y_pred = state.model(x)
        loss = state.criterion(y_pred, y)
        loss.backward()
        state.optimizer.step()
        return loss.item()

    trainer = Engine(train_step)
    evaluator = create_supervised_evaluator(model, metrics={'acc': Accuracy()})

    # extra positional arguments given to trainer.on() are forwarded to the handler;
    # STARTED fires once, before the first iteration, so train_step always finds them
    @trainer.on(Events.STARTED, model, optimizer, criterion, prepare_batch)
    def setup_state(engine, m, o, c, p):
        engine.state.model = m
        engine.state.optimizer = o
        engine.state.criterion = c
        engine.state.prepare_batch = p

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run(val_loader)

    trainer.run(train_loader, max_epochs=10)
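Both alternatives rest on the same fact: Engine calls its process function as process_function(engine, batch), so any callable of that two-argument shape will do. A bound method such as model_trainer.train_step qualifies because self is already bound, and functools.partial from the standard library is another way to bind extra parameters up front. The pure-Python sketch below checks both shapes; the trimmed-down trainer class and the placeholder model, criterion and prepare_batch functions are hypothetical stand-ins so it runs without PyTorch or Ignite:

```python
import inspect
from functools import partial


# --- hypothetical placeholders for a real model, loss and batch helper ---
def model(x):
    return x * 10


def criterion(y_pred, y):
    return abs(y_pred - y)


def prepare_batch(batch):
    return batch[0], batch[1]


# class-based option: a trimmed-down SpecificModelTrainer
class SpecificModelTrainer:
    def __init__(self, model, criterion, prepare_batch):
        self.model = model
        self.criterion = criterion
        self.prepare_batch = prepare_batch

    def train_step(self, engine, batch):
        x, y = self.prepare_batch(batch)
        return self.criterion(self.model(x), y)


# functools.partial option: every dependency is an explicit parameter
def train_step(model, criterion, prepare_batch, engine, batch):
    x, y = prepare_batch(batch)
    return criterion(model(x), y)


bound_step = SpecificModelTrainer(model, criterion, prepare_batch).train_step
partial_step = partial(train_step, model, criterion, prepare_batch)

# both callables now expose exactly the (engine, batch) parameters
print(list(inspect.signature(bound_step).parameters))    # ['engine', 'batch']
print(list(inspect.signature(partial_step).parameters))  # ['engine', 'batch']
print(bound_step(None, (2, 3)), partial_step(None, (2, 3)))  # 17 17
```

With real objects, the partial version would be wired up as Engine(partial(train_step, model, criterion, prepare_batch)), with the optimizer bound the same way.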

HTH

Thank you very much for the explanation and the usage examples!
Forgot to say: very good work on Ignite, guys!

Keep going!

Without taking anything away from Ignite, I’d suggest checking out PyTorch Lightning as well.

And, please, don’t forget others: skorch, fastai, catalyst, …
