I have the following structure in my code to use custom backward:
class myclass(Function):
    """Custom autograd Function wrapping the staged BERT pipeline so the
    hand-written stage backward() helpers are driven by loss.backward().

    NOTE(review): the reason the custom backward is never called is that
    Function.apply only attaches a grad_fn to the output when at least one
    *tensor* argument requires grad. `inputs` are token ids (integer
    tensors can never require grad) and `model_cpu` is not a tensor, so no
    graph node is created and `outputs.grad_fn` is None. Pass at least one
    float tensor with requires_grad=True through apply() (even a dummy
    one) to get backward() invoked.
    """

    @staticmethod
    def forward(ctx, inputs, model_cpu):
        # Run the pipeline stage by stage over micro-batches.
        ubatches = embedding(inputs)
        sequence_output_list, output_cpu_list = bert(ubatches)
        pooled_output_list = pooler(sequence_output_list)
        loss_list = classifier(pooled_output_list)
        # NOTE(review): `args` is not defined in this scope — presumably a
        # module-level global; confirm, or pass it explicitly into forward().
        ctx.args = args
        ctx.model_cpu = model_cpu
        return torch.stack(loss_list, dim=0).clone()

    @staticmethod
    def backward(ctx, grad_hidden_states):
        print('backward')
        model_cpu = ctx.model_cpu
        args = ctx.args
        # Chain the stage-wise backward helpers in reverse pipeline order.
        # (Fixed typo: "claasifier" -> "classifier".)
        model_cpu, classifier_input_grads = classifier.backward(
            args, model_cpu, grad_hidden_states)
        model_cpu, bert_input_grads = bert.backward(
            args, model_cpu, classifier_input_grads)
        model_cpu, embedding_input_grads = embedding.backward(
            args, model_cpu, bert_input_grads)
        # BUG FIX: backward must return exactly one gradient per forward
        # input, in order: (grad_for_inputs, grad_for_model_cpu). The
        # original returned the model object as the "gradient" for
        # `inputs`, which autograd would reject; model_cpu is non-tensor
        # state, so its gradient slot is None.
        return embedding_input_grads, None
class myclassRun(BertPreTrainedModel):
    """Runner module that builds the pipeline stages and invokes the
    custom autograd Function."""

    def __init__(self, config):
        # BUG FIX: the original called
        # super(BertForSequenceClassificationL2LRun, self).__init__(config),
        # but no class of that name exists here, so construction raises
        # NameError. Use this class's own name (or plain super()).
        super(myclassRun, self).__init__(config)
        self.embedding = BertEmbedding(config)
        self.bert = BertModel(config)
        self.pooler = BertPooler(config)
        self.bertclassifier = BertClassifier(config)

    def forward(self, inputs, model_cpu):
        # NOTE(review): myclass.forward references module-level
        # embedding/bert/pooler/classifier, not the self.* submodules
        # built above — the two need to be wired together (e.g. pass the
        # submodules through apply()); as written these attributes are
        # unused by the Function.
        outputs = myclass.apply(inputs, model_cpu)
        # This prints None when no tensor input requires grad — the
        # diagnostic for why the custom backward is never reached.
        print(outputs.grad_fn)
        print(outputs)
        return outputs
and each class instantiated in `__init__` has the following structure:
# Sketch of each stage class: a BertPreTrainedModel subclass with a manual
# forward() plus a hand-written backward() helper (NOT the autograd hook).
# Method bodies are elided in this snippet.
class BertEmbedding(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
# forward(inputs) -> this stage's outputs (body elided in the question).
def forward(self, inputs):
# NOTE(review): the signature here is incomplete — the custom Function
# above calls it as backward(args, model_cpu, upstream_grads); confirm.
def backward(self, ):
In my code the custom backward is never called. Any ideas on how I can debug this, or is there a better way to restructure the code? My goal is for `loss.backward()` to be the final backward call.