PyTorch Lightning: Confusion regarding metric logging

Hi, I am a bit confused about metric logging in training_step/validation_step.
A standard training_step looks like this:

def training_step(self, batch, batch_idx):
    labels = <from somewhere>
    logits = self.forward(batch)
    loss = F.cross_entropy(logits, labels)
    with torch.no_grad():
        # compute per-batch accuracy without tracking gradients
        preds = torch.argmax(logits, dim=1)
        correct = (preds == labels).sum()
        total = len(labels)
        acc = correct.float() / total

    log = dict(train_loss=loss, train_acc=acc, correct=correct, total=total)
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
    self.log("train_acc", acc, on_step=True)
    return dict(loss=loss, log=log)

My doubt here is: the train_loss that I am logging, is it the loss for this particular batch or the averaged loss over the entire epoch? Since I have set both on_step and on_epoch to True, what is actually getting logged, and when (after each batch, or at the end of the epoch)?
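My current understanding (possibly wrong, please correct me) is that with both flags set, Lightning logs two series: a train_loss_step value after every batch, and a single train_loss_epoch value at epoch end, where the epoch value is a mean of the step values weighted by batch size. A plain-Python sketch of that reduction as I imagine it (the names epoch_reduce, step_values, batch_sizes are mine, not Lightning API):

```python
# Sketch of how I think the on_epoch=True reduction aggregates a
# per-step metric into one epoch value: a batch-size-weighted mean.

def epoch_reduce(step_values, batch_sizes):
    """Mean of per-batch values, weighted by the size of each batch."""
    weighted_sum = sum(v * n for v, n in zip(step_values, batch_sizes))
    return weighted_sum / sum(batch_sizes)

# Three batches: per-batch losses and their sizes (last batch smaller).
losses = [1.0, 0.5, 0.25]
sizes = [32, 32, 16]

# Logged after each batch ("train_loss_step"): 1.0, 0.5, 0.25
# Logged once at epoch end ("train_loss_epoch"):
print(epoch_reduce(losses, sizes))  # (32 + 16 + 4) / 80 = 0.65
```

So the progress bar would show the per-step value during the epoch, and the epoch-level value only exists after the epoch finishes.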

About train_acc: since I have set only on_step to True, does it log just the per-batch accuracy during training and not the overall epoch accuracy?

Now with this training_step, if I add a custom training_epoch_end like this

def training_epoch_end(self, outputs) -> None:
        correct = 0
        total = 0
        for o in outputs:
            correct += o["log"]["correct"]
            total += o["log"]["total"]
        self.log("train_epoch_acc", correct / total)

Is the train_epoch_acc here the same as the average of the per-batch train_acc?
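To make my doubt concrete: with unequal batch sizes (e.g. a smaller last batch), the overall epoch accuracy (total correct / total samples, as my training_epoch_end computes) is not the same as the unweighted mean of the per-batch accuracies. A toy example with made-up numbers:

```python
# Toy illustration (hypothetical numbers): epoch accuracy computed from
# summed counts vs. a plain average of per-batch accuracies.

batches = [
    {"correct": 30, "total": 32},  # per-batch acc = 0.9375
    {"correct": 24, "total": 32},  # per-batch acc = 0.75
    {"correct": 4,  "total": 8},   # per-batch acc = 0.5 (small last batch)
]

# What the custom training_epoch_end computes:
epoch_acc = sum(b["correct"] for b in batches) / sum(b["total"] for b in batches)

# Plain (unweighted) average of per-batch accuracies:
mean_batch_acc = sum(b["correct"] / b["total"] for b in batches) / len(batches)

print(epoch_acc)       # 58 / 72 ≈ 0.8056
print(mean_batch_acc)  # (0.9375 + 0.75 + 0.5) / 3 ≈ 0.7292
```

The small last batch drags the unweighted mean down, so the two disagree; they only coincide when every batch has the same size.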

I intend to add an EarlyStopping callback that monitors the epoch-level validation loss, logged in the same fashion as train_loss.
If I just use early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss", patience=p), will it monitor the per-batch val_loss or the epoch-wise val_loss, given that val_loss is logged both at batch end and at epoch end?
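My guess is that EarlyStopping runs once per validation epoch and reads the epoch-level value of the monitored key, with patience counted in epochs without improvement. Roughly the logic I have in mind (a plain-Python sketch of mine, not Lightning's actual implementation; the function name early_stop_epoch is made up):

```python
# Sketch of patience-based early stopping evaluated once per epoch
# on the epoch-level monitored metric (here, a list of val_loss values).

def early_stop_epoch(epoch_val_losses, patience):
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(epoch_val_losses):
        if loss < best:
            # improvement: record new best and reset the counter
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None

# Hypothetical per-epoch val_loss values, patience=2:
print(early_stop_epoch([1.0, 0.8, 0.85, 0.9, 0.7], patience=2))  # 3
```

If that is right, the per-batch val_loss values never reach the callback at all; only the aggregated epoch value does.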
Sorry if my questions are a little too silly, but I am confused about this!

Thank you!

Did you ever figure this out? I have a similar question about validation_step and validation_epoch_end.

Hi, sorry for the delayed reply.

You can check this out

Hope this helps!