Is it okay to create a Loss function within the forward method?

Hi everyone, I have a model that uses HuggingFace's BERT implementation as a submodule (I share the relevant parts of the model below). My question is about the proper place for loss-function initialization. I'll use nn.CrossEntropyLoss() as the model's loss function. Is it okay to use it as below, or should I initialize it as a field of the model in the __init__ function (like self.loss_function = nn.CrossEntropyLoss()) and then use it inside the forward function? I ask because when I initialize it within the forward function, I create a new loss function at every forward call, and I don't know whether this breaks something (maybe the loss function needs to remember something?).

class MyModel(torch.nn.Module):

    def __init__(self, bert_dim, n_labels, pretrained_weights, dp=0.5):
        super().__init__()
        self.num_labels = n_labels
        self.bert1 = BertModel.from_pretrained(pretrained_weights)
        self.bert2 = BertModel.from_pretrained(pretrained_weights)
        self.dropout = nn.Dropout(dp)
        self.classifier = nn.Linear(2 * bert_dim, n_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                c_input_ids=None, labels=None):  # remaining arguments trimmed for brevity

        output1 = self.bert1(input_ids, attention_mask=attention_mask)
        output2 = self.bert2(c_input_ids, attention_mask=attention_mask)

        logits = ...  # get logits by using output1 and output2 here, omitted for the sake of simplicity
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits, loss


I could not find any information about loss modules buffering data (it may be necessary for some particular loss, but nothing suggests it here). Still, the idiomatic form is to do initializations in the __init__ method and usage in forward: initializing an object has its own overhead, because the object is re-created on every single run of the train or test loop. It may cause memory pressure too, since the previously initialized objects are no longer useful and the garbage collector has to clean them up.
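As a minimal sketch of that recommendation (using a hypothetical toy classifier in place of the BERT-based model from the question), the loss can be constructed once in __init__ and reused on every call:

```python
import torch
import torch.nn as nn

class ToyClassifier(nn.Module):
    # Hypothetical stand-in for the Bert-based model above.
    def __init__(self, in_dim, n_labels):
        super().__init__()
        self.num_labels = n_labels
        self.classifier = nn.Linear(in_dim, n_labels)
        # Created once, reused on every forward call.
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, x, labels):
        logits = self.classifier(x)
        loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits, loss

model = ToyClassifier(8, 3)
logits, loss = model(torch.randn(4, 8), torch.randint(0, 3, (4,)))
```

Because nn.CrossEntropyLoss holds no parameters, registering it as an attribute adds nothing to model.parameters(); it simply avoids re-constructing the object each step.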

So, just as you followed this convention for nn.Dropout and the other parts, I think you should do it for the loss too.
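It is also worth noting that nn.CrossEntropyLoss is stateless, so the functional form torch.nn.functional.cross_entropy gives the same result without constructing any object at all; a quick sketch to confirm:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
labels = torch.randint(0, 3, (4,))

# The module form and the functional form compute the same value;
# the functional call avoids creating an object inside forward.
loss_module = torch.nn.CrossEntropyLoss()(logits, labels)
loss_functional = F.cross_entropy(logits, labels)
```

Either way, moving the construction out of forward (or dropping it via the functional API) is the cleaner pattern.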