How to add additional cost to create_supervised_trainer?

(Jesse Livezey) #1

I’m using ignite.engine.create_supervised_trainer() and I would like to train a network with a cross-entropy loss plus an extra term that is a function of the model parameters. I want something like

def create_supervised_trainer(model, optimizer, loss_fn,
                              model_loss=None,
                              device=None, non_blocking=False,
                              prepare_batch=_prepare_batch,  # ignite's default batch helper
                              output_transform=lambda x, y, y_pred, loss: loss.item()):

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        # Optional extra term that depends only on the model parameters
        if model_loss is not None:
            loss = loss + model_loss(model)
        loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return Engine(_update)

Is there some builtin way to get something similar?

(Vfdev) #2

@JesseLivezey this can be done in several ways:

  1. If you would like to reuse create_supervised_trainer, you can write a custom loss function with the signature (y_pred, y) that closes over the model:
model = ...
loss_1 = nn.CrossEntropyLoss()
model_loss = ...

def my_custom_loss(y_pred, y):
    return loss_1(y_pred, y) + model_loss(model)

trainer = create_supervised_trainer(model, optimizer, my_custom_loss)
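
As a rough illustration of this first option, the sketch below fills in `model_loss` with a hypothetical L2 penalty. The names `l2_penalty` and `weight`, the toy `nn.Linear` model, and the random data are assumptions made for the example, not part of the original answer. It only checks that the combined loss is larger than the plain cross-entropy and that gradients reach the model parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(4, 3)                 # toy model, assumed for illustration
cross_entropy = nn.CrossEntropyLoss()


def l2_penalty(model, weight=1e-2):
    # Hypothetical model_loss: weighted sum of squared parameters.
    return weight * sum(p.pow(2).sum() for p in model.parameters())


def my_custom_loss(y_pred, y):
    # Same (y_pred, y) signature that create_supervised_trainer expects;
    # the model is captured from the enclosing scope.
    return cross_entropy(y_pred, y) + l2_penalty(model)


x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
y_pred = model(x)

base = cross_entropy(y_pred, y)
total = my_custom_loss(y_pred, y)
total.backward()                        # gradients flow through both terms
```

Because the penalty is built from `model.parameters()` at call time, it always sees the current weights, so nothing needs to be refreshed between iterations.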
  1. If you would like to reuse create_supervised_trainer, you can also wrap the combined loss in an nn.Module subclass:
class ComposedLoss(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.l1 = nn.CrossEntropyLoss()
        self.model_loss = ... 
        self.model = model

    def forward(self, y_pred, y):
        return self.l1(y_pred, y) + self.model_loss(self.model)
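
One thing worth knowing about this variant, shown in the sketch below: assigning the model to an attribute of an `nn.Module` registers it as a submodule, so the loss object exposes the model's parameters as its own. The penalty function `weight_penalty`, its `weight` argument, and the extra `model_loss` constructor argument are assumptions made for the example.

```python
import torch
import torch.nn as nn


class ComposedLoss(nn.Module):
    def __init__(self, model, model_loss):
        super().__init__()
        self.l1 = nn.CrossEntropyLoss()
        self.model_loss = model_loss    # hypothetical: callable, model -> scalar tensor
        self.model = model              # registered as a submodule of the loss

    def forward(self, y_pred, y):
        return self.l1(y_pred, y) + self.model_loss(self.model)


def weight_penalty(model, weight=1e-3):
    # Hypothetical penalty: weighted sum of absolute parameter values.
    return weight * sum(p.abs().sum() for p in model.parameters())


model = nn.Linear(4, 3)                 # toy model, assumed for illustration
criterion = ComposedLoss(model, weight_penalty)

loss = criterion(model(torch.randn(8, 4)), torch.randint(0, 3, (8,)))

# The loss module now "owns" the model's parameters as well.
loss_params = {id(p) for p in criterion.parameters()}
model_params = {id(p) for p in model.parameters()}
shares_parameters = model_params <= loss_params
```

In practice this means calls such as `criterion.to(device)` or `criterion.state_dict()` also touch the model. That is usually harmless, but it is a reason some people prefer the plain closure from the first option.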
        
  1. You can directly create a trainer with your own update function:
# model, optimizer, loss_fn, model_loss, prepare_batch, device, non_blocking
# and output_transform are assumed to be defined in the enclosing scope.
def update_fn(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    if model_loss is not None:
        loss = loss + model_loss(model)
    loss.backward()
    optimizer.step()
    return output_transform(x, y, y_pred, loss)

trainer = Engine(update_fn)

HTH

(Jesse Livezey) #3

Thanks! That’s helpful.
