I changed the LossWrapper block to:
class OutputLossWrapper(LossWrapper):
    """Loss wrapper whose ``forward`` returns only the loss value."""

    def forward(self, x, target):
        """Run the wrapped module on ``x`` and return its loss against ``target``.

        Only the loss is returned; the module's output itself is not.
        """
        # NOTE(review): self.module / self.loss_fn are presumably stored by
        # LossWrapper.__init__ -- confirm against the base class.
        return self.loss_fn(self.module(x), target)


wrapper = OutputLossWrapper(model, cross_entropy)
.
The original version, from the GitHub example I mentioned, is:
class OutputLossWrapper(LossWrapper):
    """Loss wrapper whose ``forward`` returns both the module output and the loss."""

    def __init__(self, module, loss_fn):
        # Pure pass-through to the base-class constructor.
        super().__init__(module, loss_fn)

    def forward(self, input, target):
        """Return the tuple ``(output, loss)``.

        ``output`` is the wrapped module applied to ``input``; ``loss`` is
        ``loss_fn(output, target)``.
        """
        output = self.module(input)
        return output, self.loss_fn(output, target)


wrapper = OutputLossWrapper(model, cross_entropy)
.
.
.
And now I get stage_backward calls, like this:
def forward(self, x, target):
    # Generated graph code, reproduced statement-for-statement; only the
    # indentation and these comments were added.

    # Forward pass: three chained stages. The last stage also receives `target`.
    submod_0 = self.submod_0(x)
    submod_1 = self.submod_1(submod_0)
    submod_2 = self.submod_2(submod_1, target)

    # Backward for submod_2. It is called with output_grads = (None,), i.e. no
    # incoming gradient is supplied for this stage's output.
    stage_backward = pippy_backward_stage_backward(stage_output = (submod_2,), output_grads = (None,), input_values = [submod_1, target], outputs_with_grads_idxs = [0], stage_info = 'stage_backward for stage %submod_2 : [#users=2] = call_module[target=submod_2](args = (%submod_1, %target), kwargs = {})'); target = None
    getitem = stage_backward[0]
    getitem_1 = stage_backward[1]; stage_backward = None
    # NOTE(review): element 0 is indexed once per entry of input_values, so it
    # presumably holds the gradients w.r.t. those inputs -- confirm.
    getitem_2 = getitem[0]
    getitem_3 = getitem[1]; getitem = None  # assigned but never read below

    # Backward for submod_1, fed with getitem_2 from the previous step.
    stage_backward_1 = pippy_backward_stage_backward(stage_output = (submod_1,), output_grads = (getitem_2,), input_values = [submod_0], outputs_with_grads_idxs = [0], stage_info = 'stage_backward_1 for stage %submod_1 : [#users=3] = call_module[target=submod_1](args = (%submod_0,), kwargs = {})'); submod_1 = getitem_2 = None
    getitem_4 = stage_backward_1[0]
    getitem_5 = stage_backward_1[1]; stage_backward_1 = None
    getitem_6 = getitem_4[0]; getitem_4 = None

    # Backward for submod_0, fed with getitem_6 from the previous step.
    stage_backward_2 = pippy_backward_stage_backward(stage_output = (submod_0,), output_grads = (getitem_6,), input_values = [x], outputs_with_grads_idxs = [0], stage_info = 'stage_backward_2 for stage %submod_0 : [#users=3] = call_module[target=submod_0](args = (%x,), kwargs = {})'); submod_0 = getitem_6 = x = None
    getitem_7 = stage_backward_2[0]
    getitem_8 = stage_backward_2[1]; stage_backward_2 = None
    getitem_9 = getitem_7[0]  # assigned but never read below

    # Final node: receives the last stage's output (submod_2), the second
    # element of every stage_backward result, and getitem_7. Its result is
    # what this forward returns.
    sync_barrier = pippy_backward_sync_barrier(submod_2, [getitem_1, getitem_5, getitem_8], getitem_7); submod_2 = getitem_1 = getitem_5 = getitem_8 = getitem_7 = None
    return sync_barrier
But the model doesn’t seem to train well.
I printed the loss value (= pipe_driver(x, target)), and it is greater than 1.
Why doesn’t the original OutputLossWrapper produce a backward stage?