How to set BCELoss with weight in PyTorch Lightning

Hi:

I’m training a classifier with PyTorch Lightning. I have created a LightningDataModule and a LightningModule. When the dataset was balanced, I set the criterion (used to compute the loss) in the __init__() method like this:

```python
import torch
import pytorch_lightning as pl

class LightningClassifier(pl.LightningModule):
    def __init__(self, model=None, **kwargs):
        super().__init__()
        self.criterion = torch.nn.BCELoss()
```

However, I now have a datamodule whose training dataset is unbalanced, and I want to use BCELoss(weight=weight) like this:

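```python
import torch
import pytorch_lightning as pl

class LightningClassifier(pl.LightningModule):
    def __init__(self, model=None, **kwargs):
        super().__init__()
        # Does not work: `datamodule` is not defined inside the LightningModule
        weight = datamodule.weight("train")
        self.criterion = torch.nn.BCELoss(weight=weight)
```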

This does not work, because the datamodule is not accessible when the LightningModule is being initialized. Where should I set the criterion in the LightningModule instead?

I’d appreciate any pointers.

Thanks,
Aravind.

I can’t figure this out with just the code that you shared.

Hi:

Thanks for the response. My question is a conceptual one, so I didn’t want to add too much code initially, but I have edited my question and hope it is clearer now.

Thanks,
Aravind.

Hello,
Really good question! I am pondering the same problem. Did you manage to solve it?
@Aravind_Sundaresan

You could pass a function that computes the weights to the model and call it during initialization:

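```python
from typing import Callable, Optional

import torch
import pytorch_lightning as pl

class LightningClassifier(pl.LightningModule):
    def __init__(self, model=None, weight_func: Optional[Callable] = None, **kwargs):
        super().__init__()
        weight = weight_func() if weight_func is not None else None  # unweighted if no function is given
        self.criterion = torch.nn.BCELoss(weight=weight)
```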
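Because weight_func() is called with no arguments, any data the weight computation depends on (for example the training labels) has to be bound in advance. One way to do that is functools.partial. The sketch below is plain Python: inverse_frequency_weights and the sample labels are hypothetical, and converting the result into a torch.Tensor whose shape broadcasts against the model's output is left out.

```python
from collections import Counter
from functools import partial

def inverse_frequency_weights(labels):
    """Hypothetical helper: per-class weight = n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in sorted(counts.items())}

# Bind the (hypothetical) training labels now; the resulting callable takes
# no arguments, matching how weight_func() is called in __init__.
train_labels = [0, 0, 0, 1]
weight_func = partial(inverse_frequency_weights, train_labels)

print(weight_func())  # {0: 0.6666666666666666, 1: 2.0}
```

The rarer class (label 1) ends up with the larger weight, which is the usual intent when rebalancing.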

If you’re using a config.yaml file (the class_path/init_args format used by LightningCLI), you could set it up like this:

config.yaml:

```yaml
model:
  class_path: my.module.LightningClassifier
  init_args:
    model: .....
    weight_func: my.module.my_weight_func
```

And in my/module.py you can do:

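```python
def my_weight_func():
    # LightningClassifier calls weight_func() with no arguments,
    # so this function must not require any parameters.
    ...
    return weights
```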
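For reference, with the default reduction='mean', BCELoss(weight=w) multiplies each element's loss by w before averaging: l_n = -w_n [y_n log x_n + (1 - y_n) log(1 - x_n)]. Since the weight is fixed when the criterion is constructed, it has to broadcast against the input shape of every batch. The plain-Python sketch below reproduces that arithmetic so a weight function can be sanity-checked without a model; weighted_bce is a hypothetical name, and it assumes probabilities strictly between 0 and 1 (PyTorch additionally clamps the log terms, which this sketch does not).

```python
import math

def weighted_bce(probs, targets, weights):
    """Hypothetical re-implementation of BCELoss(weight=...) with reduction='mean'."""
    if not (len(probs) == len(targets) == len(weights)):
        raise ValueError("probs, targets and weights must have the same length")
    losses = [
        -w * (y * math.log(p) + (1 - y) * math.log(1 - p))
        for p, y, w in zip(probs, targets, weights)
    ]
    # 'mean' divides by the number of elements, not by the sum of the weights
    return sum(losses) / len(losses)

# Usage: the second element counts three times as much as the first
loss = weighted_bce([0.9, 0.2], [1, 0], [1.0, 3.0])
```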