I’m wondering whether a list of tensors can be backpropagated through a custom autograd function. Below is my sample code.
from typing import List, Tuple

from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import FunctionCtx, once_differentiable


class ReversibleFunction(Function):
    @staticmethod
    def forward(
        ctx: FunctionCtx,
        x, blocks, reverse, layer_state_flags: List[bool],
    ) -> Tuple[Tensor, List[Tensor]]:
        # layer_state_flags: indicates which layers' outputs
        # are used for the intermediate loss calculation
        ctx.layer_state_flags = layer_state_flags
        blocks = blocks[::-1] if reverse else blocks
        states = [x] if layer_state_flags[0] else []
        for i, block in enumerate(blocks):
            x = block(x)
            if layer_state_flags[i + 1]:
                states.append(x)
        ctx.blocks = blocks
        ctx.save_for_backward(x.detach())
        return x, states

    @staticmethod
    @once_differentiable
    def backward(
        ctx: FunctionCtx, dy: Tensor, dys: List[Tensor]
    ):
        # dys: list of grads from the intermediate loss of each selected layer
        y, = ctx.saved_tensors
        layer_state_flags = ctx.layer_state_flags[::-1]
        for i, block in enumerate(ctx.blocks[::-1]):
            if layer_state_flags[i]:
                dy = dy + dys[i]  # out-of-place, to avoid mutating the incoming grad
            y, dy = block.backward(y, dy)
        # one gradient per forward input: x, blocks, reverse, layer_state_flags
        return dy, None, None, None
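Here block.backward(y, dy) is expected to reconstruct the block's input from its output and to push the gradient through the block by hand. For reference, a minimal sketch of the kind of reversible block this assumes (additive coupling; my real blocks differ, but the interface is the same):

import torch
from torch import nn, Tensor


class ReversibleBlock(nn.Module):
    # Additive coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)
    def __init__(self, f: nn.Module, g: nn.Module):
        super().__init__()
        self.f = f
        self.g = g

    def forward(self, x: Tensor) -> Tensor:
        x1, x2 = torch.chunk(x, 2, dim=-1)
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return torch.cat([y1, y2], dim=-1)

    def backward(self, y: Tensor, dy: Tensor):
        # Reconstruct the input from the output and propagate dy manually,
        # so no activations have to be stored during forward.
        y1, y2 = torch.chunk(y, 2, dim=-1)
        dy1, dy2 = torch.chunk(dy, 2, dim=-1)

        with torch.enable_grad():
            y1 = y1.detach().requires_grad_(True)
            gy1 = self.g(y1)
        gy1.backward(dy2)            # accumulates grads for G's params and y1.grad

        with torch.no_grad():
            x2 = y2 - gy1            # invert the second coupling
            dx1 = dy1 + y1.grad      # total grad w.r.t. y1

        with torch.enable_grad():
            x2 = x2.detach().requires_grad_(True)
            fx2 = self.f(x2)
        fx2.backward(dx1)            # accumulates grads for F's params and x2.grad

        with torch.no_grad():
            x1 = y1 - fx2            # invert the first coupling
            dx2 = dy2 + x2.grad

        x = torch.cat([x1, x2.detach()], dim=-1)
        dx = torch.cat([dx1, dx2], dim=-1)
        return x, dx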
Since my blocks are reversible, I don’t need to store the intermediate outputs (states) from each selected layer during the forward pass. However, I want to use those intermediate outputs to calculate intermediate losses and backpropagate them as well. I wonder whether the gradient input dys to backward will also come in the form of a list.
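For concreteness, this is roughly how I intend to use it (the blocks, loss function, and data below are just placeholders):

import torch
from torch import nn

# Placeholder setup: three reversible blocks over a 4-dim feature.
blocks = [ReversibleBlock(nn.Linear(2, 2), nn.Linear(2, 2)) for _ in range(3)]
x = torch.randn(8, 4)
target = torch.randn(8, 4)
criterion = nn.MSELoss()

# One flag per state: the input x, then each block's output.
layer_state_flags = [False, True, False, True]

out, states = ReversibleFunction.apply(x, blocks, False, layer_state_flags)

loss = criterion(out, target)
for s in states:
    loss = loss + criterion(s, target)  # intermediate loss on each selected output

loss.backward()  # the open question: do the grads of `states` reach backward() as a list?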