I’m using ignite.engine.create_supervised_trainer() and I would like to train a network using a crossentropy loss plus a function of the model parameters. I want something like
def create_supervised_trainer(model, optimizer, loss_fn,
                              model_loss=None,
                              output_transform=lambda x, y, y_pred, loss: loss.item(),
                              device=None, non_blocking=False,
                              prepare_batch=None):
    """Build an ignite ``Engine`` whose update step trains ``model`` with
    ``loss_fn`` plus an optional extra loss computed from the model itself.

    Args:
        model: the network to train (put into ``train()`` mode each step).
        optimizer: optimizer whose ``zero_grad``/``step`` bracket each update.
        loss_fn: callable ``(y_pred, y) -> loss`` (e.g. cross-entropy).
        model_loss: optional callable ``(model) -> loss`` added to the data
            loss (e.g. a penalty on the parameters). ``None`` disables it.
        output_transform: maps ``(x, y, y_pred, loss)`` to the engine output;
            defaults to the scalar loss value.
        device: optional device the batch is moved to before the forward pass.
        non_blocking: passed through to ``Tensor.to`` when moving the batch.
        prepare_batch: optional callable ``(batch, device, non_blocking) ->
            (x, y)``; defaults to a simple mover for ``(x, y)`` tuples.

    Returns:
        An ``ignite.engine.Engine`` running the update loop above.
    """
    def _default_prepare_batch(batch, device=None, non_blocking=False):
        # Minimal stand-in for ignite's _prepare_batch: assumes the batch is
        # an (input, target) pair of tensors and moves both to `device`.
        x, y = batch
        if device is not None:
            x = x.to(device, non_blocking=non_blocking)
            y = y.to(device, non_blocking=non_blocking)
        return x, y

    if prepare_batch is None:
        prepare_batch = _default_prepare_batch

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        if model_loss is not None:
            # Fold in the parameter-dependent term before backprop so its
            # gradient flows into the same optimizer step.
            loss = loss + model_loss(model)
        loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return Engine(_update)
Is there some builtin way to get something similar?
You can directly create a trainer with your own update function:
def update_fn(engine, batch):
    """One training step: forward pass, (optionally regularized) loss,
    backward pass, optimizer step. Returns whatever `output_transform`
    maps the step's tensors to."""
    model.train()
    optimizer.zero_grad()
    inputs, targets = prepare_batch(batch, device=device, non_blocking=non_blocking)
    predictions = model(inputs)
    step_loss = loss_fn(predictions, targets)
    # Add the model-parameter loss term only when one was supplied.
    step_loss = step_loss if model_loss is None else step_loss + model_loss(model)
    step_loss.backward()
    optimizer.step()
    return output_transform(inputs, targets, predictions, step_loss)

trainer = Engine(update_fn)