Custom backward with staticmethod

I use the following structure in my code to implement a custom backward:

class myclass(Function):

    @staticmethod
    def forward(ctx, inputs, model_cpu):
        # embedding, bert, pooler and classifier are sub-modules defined elsewhere
        ubatches = embedding(inputs)
        sequence_output_list, output_cpu_list = bert(ubatches)
        pooled_output_list = pooler(sequence_output_list)
        loss_list = classifier(pooled_output_list)

        ctx.args = args  # args is not a forward() argument; it must exist in the enclosing scope
        ctx.model_cpu = model_cpu

        return torch.stack(loss_list, dim=0).clone()  # , output_cpu_list, bert_output  # (loss), logits, (hidden_states), (attentions)

    @staticmethod
    def backward(ctx, grad_hidden_states):
        print('backward')
        model_cpu = ctx.model_cpu
        args = ctx.args

        model_cpu, classifier_input_grads = classifier.backward(args, model_cpu, grad_hidden_states)
        model_cpu, BertModel_input_grads = bert.backward(args, model_cpu, classifier_input_grads)
        model_cpu, Embedding_input_grads = embedding.backward(args, model_cpu, BertModel_input_grads)

        # One return value per forward() input, in order:
        # the gradient for inputs, then None for the non-tensor model_cpu
        return Embedding_input_grads, None
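For reference, a minimal, self-contained sketch of the contract autograd expects from a Function with static forward and backward methods. ScaledSquare and its scale argument are hypothetical stand-ins, not part of the code above: backward receives one gradient per output and must return one value per forward input (after ctx), with None for non-tensor arguments such as model_cpu.

```python
import torch
from torch.autograd import Function


class ScaledSquare(Function):
    """Hypothetical toy stand-in for myclass: y = scale * x**2."""

    @staticmethod
    def forward(ctx, x, scale):
        # Tensors needed in backward go through save_for_backward;
        # plain Python values (like scale) can be stored on ctx directly.
        ctx.save_for_backward(x)
        ctx.scale = scale
        return scale * x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad_x = grad_output * 2 * ctx.scale * x
        # One value per forward input: a gradient for x, None for scale.
        return grad_x, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = ScaledSquare.apply(x, 0.5)
y.sum().backward()
print(x.grad)  # tensor([1., 2., 3.])
```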



class myclassRun(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        #self.seq_classifier = BertForSequenceClassificationL2L_new (config)
        self.embedding = BertEmbedding(config)
        self.bert = BertModel(config)
        self.pooler = BertPooler(config)
        self.bertclassifier = BertClassifier(config)

    def forward(self, inputs, model_cpu):

        outputs = myclass.apply(inputs, model_cpu)

        print(outputs.grad_fn)
        print(outputs)
        return outputs

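Each class instantiated in __init__ has the following structure:

class BertEmbedding(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, inputs):
        ...  # regular forward computation

    def backward(self, args, model_cpu, grad_output):
        ...  # hand-written backward; returns (model_cpu, input_grads)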

In my code the custom backward is never called. Any idea how I can debug this, or is there a better way to restructure it? My goal is to have loss.backward() as my final backward call.
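On restructuring: one common pattern, sketched here under assumptions, is to give each layer its own Function and call it from that layer's nn.Module forward, passing the layer's parameters to apply() explicitly. A single loss.backward() then chains through every hand-written backward, last layer first. _ScaleFn, ScaleLayer and Pipeline are hypothetical toy stand-ins for the BERT sub-modules; the elementwise math only keeps the example small.

```python
import torch
import torch.nn as nn
from torch.autograd import Function


class _ScaleFn(Function):
    """Hypothetical hand-written layer: y = x * weight (elementwise)."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x * weight

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        print("custom backward of _ScaleFn")
        # Assumes x and weight have the same shape; a batched version
        # would need to sum grad_w over the batch dimension.
        return grad_output * weight, grad_output * x


class ScaleLayer(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(size))

    def forward(self, x):
        # Passing the parameter to apply() makes it a tracked input,
        # so its gradient lands in self.weight.grad.
        return _ScaleFn.apply(x, self.weight)


class Pipeline(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.layer1 = ScaleLayer(size)
        self.layer2 = ScaleLayer(size)

    def forward(self, x):
        return self.layer2(self.layer1(x))


model = Pipeline(3)
x = torch.tensor([1.0, 2.0, 3.0])  # the data itself does not require grad
loss = model(x).sum()
loss.backward()  # runs _ScaleFn.backward once per layer
print(model.layer1.weight.grad)  # tensor([1., 2., 3.])
```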

Hi,

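Which backward() is not called? The one from the Function?
If your BertEmbedding is an nn.Module, then it is expected that its backward() won't be called: autograd only calls the backward of a torch.autograd.Function, never a method that happens to be named backward on an nn.Module.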
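A small sketch illustrating that point: a backward method on a module is an ordinary method, and autograd differentiates the module's forward on its own. ModuleWithBackward is a hypothetical toy module.

```python
import torch
import torch.nn as nn


class ModuleWithBackward(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(3.0))
        self.backward_called = False

    def forward(self, x):
        return x * self.weight

    def backward(self, grad_output):
        # Ordinary method: autograd does not know it exists.
        self.backward_called = True
        return grad_output


m = ModuleWithBackward()
loss = m(torch.tensor(2.0)).sum()
loss.backward()
print(m.backward_called)  # False
print(m.weight.grad)      # tensor(2.)
```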

When I call loss.backward(), I expect it to go through myclassRun's backward, which is supposed to call myclass's static backward method, but that static backward is never called. Inside that function I call the custom backward of BertEmbedding.
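One likely cause, offered as an assumption rather than a diagnosis: autograd only records a Function call if at least one tensor passed to apply() requires grad. Token ids never can, and parameters used inside Function.forward are invisible to autograd because forward runs with gradient tracking disabled. The print(outputs.grad_fn) in myclassRun.forward is the quickest check: None means the call was never recorded. Double is a hypothetical toy Function.

```python
import torch
from torch.autograd import Function


class Double(Function):
    """Hypothetical toy Function: y = 2 * x."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        print("Double.backward called")
        return grad_output * 2


# Stand-in for token ids: a tensor that does not require grad.
ids = torch.tensor([1.0, 2.0])
out_no_grad = Double.apply(ids)
print(out_no_grad.grad_fn, out_no_grad.requires_grad)  # None False

# With an input that requires grad, the call is recorded and
# backward runs when the graph is backpropagated.
x = torch.tensor([1.0, 2.0], requires_grad=True)
out = Double.apply(x)
print(out.grad_fn is not None)  # True
out.sum().backward()  # prints "Double.backward called"
print(x.grad)  # tensor([2., 2.])
```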

And what do the prints in your Module's forward show?
If you call outputs.sum().backward() right there, it should call your custom backward just fine, right?
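A minimal sketch of that check, plus torch.autograd.gradcheck, which compares a hand-written backward against finite differences. Cube and its calls counter are hypothetical; gradcheck expects double-precision inputs that require grad.

```python
import torch
from torch.autograd import Function, gradcheck


class Cube(Function):
    calls = 0  # hypothetical counter, only for this check

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        Cube.calls += 1
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2


x = torch.randn(4, dtype=torch.double, requires_grad=True)
outputs = Cube.apply(x)
outputs.sum().backward()  # the quick check: backward should run once
print(Cube.calls)  # 1

# Numerically verify the hand-written backward.
print(gradcheck(Cube.apply, (x,)))  # True
```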