How to make a custom callback (stop training at a certain loss value)

Hello all

I am a beginner in deep learning and am doing research with Keras and PyTorch. I want to stop training by watching the loss value and checking a condition. I have succeeded in making a custom stop-training callback in Keras, and I want to do the same thing in PyTorch, but I am facing problems. First, how should I write the loss metric for my stop-training function? And after I write a custom callback function, how do I call it?

I am considering using Ignite to write a custom callback, but I don't know how to do it.

This is my custom stop-training callback in Keras:

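import numpy as np
from tensorflow import keras

class StopatLossValue(keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        THR = 0.1
        loss = (logs or {}).get('loss')
        # Stop as soon as the reported loss drops below THR**2 (= 0.01)
        if loss is not None and loss < np.square(THR):
            self.model.stop_training = True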

Below is my PyTorch model and training loop:

import torch
import torch.nn as nn

# CREATE MODEL
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        # 784 inputs (a flattened 28x28 image) -> 100 hidden units -> 10 classes
        self.layers = nn.Sequential(
            nn.Linear(784, 100),
            nn.Sigmoid(),
            nn.Linear(100, 10)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten each image into a 784-vector
        x = self.layers(x)
        return x

model = MLP()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    model.train()

    train_losses = []
    valid_losses = []
    for i, (images, labels) in enumerate(train_loader):  # one pass over the training set (one epoch)
        optimizer.zero_grad()

        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

        # Print progress every 100 batches (assumes batch size 128 and 50,000 training samples)
        if i % 100 == 0:
            print(f'{i * 128} / 50000')
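A minimal sketch of the same stop rule placed directly in a loop like the one above, checking each batch loss against `thr**2` the way the Keras `on_batch_end` callback does. It assumes a hypothetical helper name `train_until`, an `nn.Sequential` equivalent of `MLP` on already-flattened inputs, and random tensors wrapped in `TensorDataset`/`DataLoader` as a stand-in for the real `train_loader`. Because `break` only leaves the inner loop, a `stop_training` flag is used to leave the epoch loop as well.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_until(model, loader, epochs, thr=0.1, lr=0.1):
    """Train until a batch loss drops below thr**2 or the epochs run out.

    Returns (epochs_run, last_batch_loss, stopped_early).
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    stop_training = False  # plays the role of Keras' model.stop_training
    last_loss = float('nan')
    epochs_run = 0
    for _ in range(epochs):
        epochs_run += 1
        model.train()
        for images, labels in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
            # Same condition as the Keras callback: loss < THR**2
            if last_loss < thr ** 2:
                stop_training = True
                break  # exits the batch loop only
        if stop_training:
            break  # exits the epoch loop as well
    return epochs_run, last_loss, stop_training


if __name__ == '__main__':
    torch.manual_seed(0)
    # Hypothetical stand-in for train_loader: 256 random flattened "images", 10 classes
    x = torch.randn(256, 784)
    y = torch.randint(0, 10, (256,))
    loader = DataLoader(TensorDataset(x, y), batch_size=128, shuffle=True)
    model = nn.Sequential(nn.Linear(784, 100), nn.Sigmoid(), nn.Linear(100, 10))
    print(train_until(model, loader, epochs=5))
```

A single batch loss is noisy, so the rule may fire on one lucky batch. My understanding is that in recent TF 2.x versions the `logs['loss']` passed to Keras batch-end hooks is a running average over the epoch so far, so the two versions may not stop at exactly the same step.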

Hi,

In your PyTorch training loop, you already have access to both `loss` at every iteration and `train_losses` at the end of each epoch.
So you don't actually need a callback in PyTorch: you can write a plain function that checks whatever you need and call it directly from the loop.
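A minimal sketch of such a plain function, assuming a hypothetical name `should_stop` and choosing the mean of the epoch's `train_losses` (rather than the last batch) as the value compared against `thr**2`:

```python
from statistics import mean


def should_stop(train_losses, thr=0.1):
    """Return True if the mean loss of the finished epoch is below thr**2.

    train_losses: list of per-batch loss floats (from loss.item()) for one epoch.
    """
    if not train_losses:
        return False  # nothing recorded yet, keep training
    return mean(train_losses) < thr ** 2
```

Placing `if should_stop(train_losses): break` right after the inner batch loop (still inside `for epoch in range(epochs):`) ends training at the end of the first epoch whose mean loss is below the threshold.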

Thank you for your reply. I tried writing it and there is no error, but I am not sure whether it is right or not.

class StopatLossValue:
    def on_epoch_end(self, train_losses, **kwargs):
        # If the mean training loss of the epoch is below eo**2, set a flag to stop training
        eo = 0.1
        mean_loss = sum(train_losses) / len(train_losses)
        return {'stop_training': mean_loss < eo ** 2}

and I called it in my model like this:

StopatLossValue()
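Writing `StopatLossValue()` only creates the object; `on_epoch_end` never runs unless the training loop calls it and acts on the result. A minimal sketch of a Keras-like loop that does this, assuming a hypothetical `fit` helper, a variant of the class above with `eo` moved to a constructor argument, and random tensors standing in for `train_loader`:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class StopatLossValue:
    """Epoch-end check: ask to stop once the mean epoch loss is below eo**2."""

    def __init__(self, eo=0.1):
        self.eo = eo

    def on_epoch_end(self, train_losses, **kwargs):
        mean_loss = sum(train_losses) / len(train_losses)
        return {'stop_training': mean_loss < self.eo ** 2}


def fit(model, loader, epochs, callbacks=(), lr=0.1):
    """Hypothetical Keras-like loop; returns the number of epochs actually run."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for images, labels in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        # Call every callback at the end of the epoch and honour its answer
        results = [cb.on_epoch_end(train_losses, epoch=epoch) for cb in callbacks]
        # Callbacks that return None are treated as "keep training"
        if any(r.get('stop_training', False) for r in results if r):
            return epoch + 1
    return epochs
```

With the names from the question, usage would look like `fit(model, train_loader, epochs, callbacks=[StopatLossValue(eo=0.1)])`. Returning a dict keeps the callback unaware of the loop, much like setting `self.model.stop_training` in Keras.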