How to set BCELoss with a weight in PyTorch Lightning

Hi:

I’m training a classifier with PyTorch Lightning. I have created a LightningDataModule and a LightningModule. When the dataset was balanced, I set the criterion (used to compute the loss) in the `__init__()` method like this:

```python
import torch
import pytorch_lightning as pl
class LightningClassifier(pl.LightningModule):
    def __init__(self, model=None, **kwargs):
        super().__init__()
        self.criterion = torch.nn.BCELoss()  # unweighted binary cross-entropy
```

However, I now have a datamodule with an unbalanced training dataset, and I want to use `BCELoss(weight=weight)` like this:

```python
class LightningClassifier(pl.LightningModule):
    def __init__(self, model=None, **kwargs):
        super().__init__()
        # `datamodule` is not defined inside the module, so this line fails
        weight = datamodule.weight("train")
        self.criterion = torch.nn.BCELoss(weight=weight)
```

This does not work, because the datamodule is not accessible when the LightningModule is initialized. Where in the LightningModule should I create the criterion?

I’d appreciate any pointers.

Thanks,
Aravind.

I can’t figure this out with just the code that you shared.

Hi:

Thanks for the response. My question is conceptual, so I didn’t want to add too much code initially, but I have edited my question and hope it is clear now.

Thanks,
Aravind.
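One caveat about the `BCELoss(weight=weight)` call in the question: in `torch.nn.BCELoss` the `weight` argument rescales the loss of each element (the PyTorch docs describe it as a tensor of size nbatch), so it is not a per-class weight. For an imbalanced binary problem, `torch.nn.BCEWithLogitsLoss(pos_weight=...)` is the usual tool; note that it expects raw logits rather than sigmoid outputs. A minimal sketch, assuming the training labels are available as a 0/1 tensor (the helper name `compute_pos_weight` is hypothetical):

```python
import torch


def compute_pos_weight(labels: torch.Tensor) -> torch.Tensor:
    """Ratio of negative to positive samples, one value per label column."""
    labels = labels.float()
    num_pos = labels.sum(dim=0)
    num_neg = labels.shape[0] - num_pos
    # clamp avoids division by zero if a label never appears as positive
    return num_neg / num_pos.clamp(min=1.0)


# Toy labels: 3 negatives and 1 positive, so positives get weight 3
train_labels = torch.tensor([0, 0, 0, 1])
pos_weight = compute_pos_weight(train_labels)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.tensor([0.2, -1.0, 0.5, 1.5])  # raw model outputs, no sigmoid
loss = criterion(logits, train_labels.float())
```

The same `pos_weight` tensor can be computed in the datamodule and handed to the LightningModule in either of the ways discussed in this thread.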