I’m training a classifier with PyTorch Lightning. I have created a `LightningDataModule` and a `LightningModule`. When the dataset was balanced, I set the criterion (used to compute the loss) in the `__init__()` method like this:
```python
import torch
import pytorch_lightning as pl


class LightningClassifier(pl.LightningModule):
    def __init__(self, model=None, **kwargs):
        super().__init__()
        self.criterion = torch.nn.BCELoss()
```
However, I now have a datamodule with an unbalanced training dataset, and I want to use `BCELoss(weight=weight)` like this:
```python
class LightningClassifier(pl.LightningModule):
    def __init__(self, model=None, **kwargs):
        super().__init__()
        # Fails: no `datamodule` object is in scope when the module is constructed
        weight = datamodule.weight("train")
        self.criterion = torch.nn.BCELoss(weight=weight)
```
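A note on the criterion itself, as a sketch with made-up data: `BCELoss(weight=...)` rescales the loss of each element, so `weight` must broadcast to the input's shape, and a weight that depends on whether each target is positive has to be rebuilt from every batch's targets. For class imbalance, `BCEWithLogitsLoss(pos_weight=...)` (which expects raw logits, not probabilities) takes one fixed weight per class. The example below assumes `pos_weight = negatives / positives`, a common heuristic rather than a rule, and checks that it gives the same loss as `BCELoss` with per-element weights derived from the targets.

```python
import torch

# Toy, hypothetical batch: 2 positives and 6 negatives, shape (N, 1).
targets = torch.tensor([[1.0], [0.0], [0.0], [1.0], [0.0], [0.0], [0.0], [0.0]])
logits = torch.zeros_like(targets)  # untrained model: sigmoid(0) = 0.5 everywhere

# Per-class weight for the positive term: negatives / positives.
num_pos = targets.sum(dim=0)
pos_weight = (targets.shape[0] - num_pos) / num_pos  # tensor([3.])

# Option 1: BCEWithLogitsLoss with a fixed per-class pos_weight.
criterion_pos = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss_pos = criterion_pos(logits, targets)

# Option 2: BCELoss with a per-element weight, which must be rebuilt from
# each batch's targets (pos_weight for positives, 1.0 for negatives).
per_element_weight = 1.0 + (pos_weight - 1.0) * targets
criterion_elem = torch.nn.BCELoss(weight=per_element_weight)
loss_elem = criterion_elem(torch.sigmoid(logits), targets)
```

Because the per-element version depends on the batch, a `weight` tensor fixed once in `__init__()` only behaves like a class weight if it is reshaped to match every batch, which is why `pos_weight` is usually the more convenient choice for this case.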
In the example above this is not possible, because I do not have access to the datamodule when the `LightningModule` is initialized. Where should I set the criterion in the `LightningModule`?
Any pointers would be appreciated.