Hi,
I’m training a classifier with PyTorch Lightning. I have created a LightningDataModule and a LightningModule. When the dataset was balanced, I set the criterion (used to compute the loss) in the `__init__()` method like this:
```python
import torch
import pytorch_lightning as pl

class LightningClassifier(pl.LightningModule):
    def __init__(self, model=None, **kwargs):
        super().__init__()
        self.criterion = torch.nn.BCELoss()
```
However, I now have a datamodule with an unbalanced training dataset, and I want to use `BCELoss(weight=weight)` like this:
```python
class LightningClassifier(pl.LightningModule):
    def __init__(self, model=None, **kwargs):
        super().__init__()
        # Problem: `datamodule` is not in scope inside __init__
        weight = datamodule.weight("train")
        self.criterion = torch.nn.BCELoss(weight=weight)
```
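A side note on the weight itself: in PyTorch, the `weight` argument of `BCELoss` is a manual rescaling weight applied to the loss of each batch element, so it is not automatically a class-balancing weight. A common alternative for imbalanced binary or multi-label targets is `BCEWithLogitsLoss(pos_weight=...)`, where the PyTorch docs suggest setting each class's `pos_weight` to (number of negatives) / (number of positives). That loss expects raw logits, not probabilities, so it would replace rather than wrap the `BCELoss` above. The sketch below assumes 0/1 multi-label targets of shape `(num_samples, num_classes)`; `compute_pos_weight` is a hypothetical helper, not part of PyTorch or Lightning.

```python
import torch


def compute_pos_weight(labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-class (#negatives / #positives).

    `labels` is assumed to be a 0/1 tensor of shape (num_samples, num_classes).
    """
    labels = labels.float()
    num_pos = labels.sum(dim=0)             # positives per class
    num_neg = labels.shape[0] - num_pos     # negatives per class
    # clamp avoids division by zero for a class with no positive samples
    return num_neg / num_pos.clamp(min=1.0)


# Toy labels: class 0 has 1 positive out of 4, class 1 has 2 out of 4.
labels = torch.tensor([[1, 0], [0, 0], [0, 1], [0, 1]])
pos_weight = compute_pos_weight(labels)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```

For the toy labels this gives a `pos_weight` of 3.0 for class 0 and 1.0 for class 1, so missed positives of the rarer class cost three times as much.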
This doesn’t work, because I don’t have access to the datamodule when the LightningModule is initialized. Where in the LightningModule should I set the criterion?
I’d appreciate any pointers.
Thanks,
Aravind.