@damicoedoardo The basic usage of Ignite is shown below. Most handlers rely heavily on variables from the enclosing scope (here, the local variables of `main`, captured as closures):
def main():
    """Train with Ignite using closures over ``main``'s local variables.

    ``train_step`` and ``run_validation`` read ``model``, ``optimizer``,
    ``criterion``, ``evaluator`` and ``val_loader`` directly from this
    function's scope instead of receiving them as arguments.
    """
    train_loader, val_loader = get_dataloader(...)
    model = ...
    # Presumably built from the model (e.g. from model.parameters()),
    # so it is defined after the model.
    optimizer = ...
    criterion = ...

    def train_step(engine, batch):
        # model, optimizer and criterion are captured from main's scope;
        # prepare_batch is assumed to be a helper defined outside main.
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(train_step)
    evaluator = create_supervised_evaluator(model, metrics={'acc': Accuracy()})

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        # evaluator and val_loader are also captured from main's scope.
        evaluator.run(val_loader)

    trainer.run(train_loader, max_epochs=10)
If you prefer not to rely on the enclosing scope as above, there are a couple of alternatives:
- define a custom `Trainer`-like class and pass everything it needs to it:
class SpecificModelTrainer:
    """Bundle the objects a training step needs so that no closures are required.

    The bound method ``train_step`` has the ``(engine, batch)`` signature
    expected of an ``Engine`` process function.
    """

    def __init__(self, model, optimizer, criterion, *args,
                 prepare_batch=None, **kwargs):
        """Store the training components on the instance.

        Args:
            model: callable model that also provides ``train()``.
            optimizer: object providing ``zero_grad()`` and ``step()``.
            criterion: callable ``criterion(y_pred, y)`` returning a loss
                object with ``backward()`` and ``item()``.
            *args: extra positional settings, kept as ``self.args``.
            prepare_batch: optional callable mapping a raw batch to ``(x, y)``.
                When omitted, the batch itself must already be an ``(x, y)`` pair.
            **kwargs: extra keyword settings, kept as ``self.kwargs``.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        if prepare_batch is None:
            prepare_batch = self._default_prepare_batch
        self.prepare_batch = prepare_batch
        self.args = args
        self.kwargs = kwargs

    @staticmethod
    def _default_prepare_batch(batch):
        """Unpack a batch that is already an ``(x, y)`` pair."""
        x, y = batch
        return x, y

    def train_step(self, engine, batch):
        """Run one optimization step on ``batch`` and return ``loss.item()``.

        ``engine`` is not used here; it is accepted only because it is part of
        the process-function signature.
        """
        # All components come from the instance, not from an enclosing scope.
        self.model.train()
        self.optimizer.zero_grad()
        x, y = self.prepare_batch(batch)
        y_pred = self.model(x)
        loss = self.criterion(y_pred, y)
        loss.backward()
        self.optimizer.step()
        return loss.item()
def main():
    """Train with Ignite using a trainer object instead of closures.

    The training components are passed to ``SpecificModelTrainer``
    explicitly, and its bound ``train_step`` method becomes the engine's
    process function.
    """
    train_loader, val_loader = get_dataloader(...)
    model = ...
    # Presumably built from the model (e.g. from model.parameters()),
    # so it is defined after the model.
    optimizer = ...
    criterion = ...
    # prepare_batch is assumed to be a helper defined outside main.
    model_trainer = SpecificModelTrainer(
        model, optimizer, criterion, prepare_batch=prepare_batch
    )
    trainer_engine = Engine(model_trainer.train_step)
    ...  # evaluator, event handlers and trainer_engine.run(...) go here
- or store everything on `trainer.state` instead:
def main():
    """Train with Ignite, reading the training components from ``engine.state``.

    ``setup_state`` copies the components onto ``trainer.state`` on
    ``Events.STARTED``, and ``train_step`` reads them back from
    ``engine.state``.
    """
    train_loader, val_loader = get_dataloader(...)
    model = ...
    # Presumably built from the model (e.g. from model.parameters()),
    # so it is defined after the model.
    optimizer = ...
    criterion = ...

    def train_step(engine, batch):
        # Everything is read from engine.state (populated by setup_state
        # below), not from main's scope.
        state = engine.state
        state.model.train()
        state.optimizer.zero_grad()
        x, y = state.prepare_batch(batch)
        y_pred = state.model(x)
        loss = state.criterion(y_pred, y)
        loss.backward()
        state.optimizer.step()
        return loss.item()

    trainer = Engine(train_step)
    evaluator = create_supervised_evaluator(model, metrics={'acc': Accuracy()})

    # The extra arguments given to trainer.on(...) reach the handler after the
    # engine, so setup_state receives (engine, model, optimizer, criterion,
    # prepare_batch). prepare_batch is assumed to be defined outside main.
    @trainer.on(Events.STARTED, model, optimizer, criterion, prepare_batch)
    def setup_state(engine, m, o, c, p):
        engine.state.model = m
        engine.state.optimizer = o
        engine.state.criterion = c
        engine.state.prepare_batch = p

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run(val_loader)

    trainer.run(train_loader, max_epochs=10)
HTH