I’m wondering whether a list of tensors can receive gradients in `backward` of a custom autograd function. Below is my sample code:
```python
from typing import List, Tuple

import torch
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import FunctionCtx, once_differentiable


class ReversibleFunction(Function):
    @staticmethod
    def forward(
        ctx: FunctionCtx,
        x,
        blocks,
        reverse,
        layer_state_flags: List[bool],
    ) -> Tuple[Tensor, List[Tensor]]:
        # layer_state_flags: indicate the outputs from which layers
        # are used for intermediate loss calculation
        ctx.layer_state_flags = layer_state_flags
        blocks = blocks[::-1] if reverse else blocks
        states = [x] if layer_state_flags[0] else []
        for i, block in enumerate(blocks):
            x = block(x)
            if layer_state_flags[i + 1]:
                states.append(x)
        ctx.blocks = blocks
        ctx.save_for_backward(x.detach())
        return x, states

    @staticmethod
    @once_differentiable
    def backward(ctx: FunctionCtx, dy: Tensor, dys: List[Tensor]):
        # dys: list of grads from the intermediate loss of each selected layer
        (y,) = ctx.saved_tensors
        layer_state_flags = ctx.layer_state_flags[::-1]
        for i, block in enumerate(ctx.blocks[::-1]):
            if layer_state_flags[i]:
                dy += dys[i]
            # each reversible block reconstructs its input and the grad w.r.t. it
            y, dy = block.backward(y, dy)
        # one gradient per forward input: x, blocks, reverse, layer_state_flags
        return dy, None, None, None
```
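For reference, this is roughly what I'm trying to reproduce, written with plain (non-reversible) autograd. It's only an illustrative sketch; the `nn.Linear` blocks and the squared-mean losses are placeholders, not my actual model or losses:

```python
import torch
import torch.nn as nn

# Hypothetical toy setup for illustration only (not my real blocks/losses).
blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(3)])
layer_state_flags = [False, True, False, True]  # input flag + one flag per block

x = torch.randn(8, 64, requires_grad=True)
states = [x] if layer_state_flags[0] else []
for i, block in enumerate(blocks):
    x = block(x)
    if layer_state_flags[i + 1]:
        states.append(x)

main_loss = x.pow(2).mean()
aux_losses = [s.pow(2).mean() for s in states]  # intermediate losses

# With plain autograd this works, but it keeps all of `states` alive in memory.
# I want the same gradients to reach my custom Function's backward as `dys`.
(main_loss + sum(aux_losses)).backward()
```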
Since my blocks are reversible, I don't need to store `states` (the intermediate outputs of the selected layers) for the backward pass. However, I do want to use those intermediate outputs to calculate intermediate losses and backpropagate them. Will the gradient input `dys` to `backward` also arrive as a list of tensors?
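If a plain Python list can't carry gradients into `backward`, would flattening the states into separate tensor outputs be the intended pattern? Below is a rough, untested sketch of what I mean; it keeps the same gradient routing as my code above, only the return value and the `backward` signature change:

```python
class ReversibleFunctionFlat(Function):
    # Hypothetical variant: same logic as above, but the selected states are
    # returned as individual tensor outputs instead of a Python list.

    @staticmethod
    def forward(ctx, x, blocks, reverse, layer_state_flags):
        ctx.layer_state_flags = layer_state_flags
        blocks = blocks[::-1] if reverse else blocks
        states = [x] if layer_state_flags[0] else []
        for i, block in enumerate(blocks):
            x = block(x)
            if layer_state_flags[i + 1]:
                states.append(x)
        ctx.blocks = blocks
        ctx.save_for_backward(x.detach())
        return (x, *states)  # flat tuple of tensors, no list

    @staticmethod
    @once_differentiable
    def backward(ctx, dy, *dys):
        # dys: one incoming gradient per extra tensor output of forward
        (y,) = ctx.saved_tensors
        layer_state_flags = ctx.layer_state_flags[::-1]
        for i, block in enumerate(ctx.blocks[::-1]):
            if layer_state_flags[i]:
                dy = dy + dys[i]
            y, dy = block.backward(y, dy)
        return dy, None, None, None
```

Is this the recommended way, or does autograd also track a list returned from `forward`?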