Gradient Checkpointing returning values

I have a checkpoint callback function (i.e., `custom_dec`) that returns a tensor and a dictionary. However, `checkpoint` seems to support only tensor return values, not dictionaries (or other Python types). What is the workaround, given that the module I want to checkpoint returns a tensor plus a dictionary?

    def custom_dec(self, module):
        def custom_forward(*inputs):
            output = module(inputs[0], inputs[1],
                            encoder_attn_mask=inputs[2],
                            decoder_padding_mask=inputs[3],
                            layer_state=inputs[4],
                            causal_mask=inputs[5],
                            output_attentions=inputs[6],
                            )
            # output[2] is a Python dictionary
            return output[0], output[2]
        return custom_forward

The following is the checkpoint call:

    x, layer_past = checkpoint.checkpoint(
        self.custom_dec(decoder_layer),
        x,
        encoder_hidden_states,
        encoder_padding_mask,
        decoder_padding_mask,
        layer_state,
        decoder_causal_mask,
        output_attentions,
    )

The error:

    TypeError: CheckpointFunctionBackward.forward: expected Variable (got dictionary) for return value 1

Hi,

Assuming the dictionary is not a differentiable output, you can return it by side effect: write it into a container captured by the closure, and return only the tensor through `checkpoint`:

    def custom_dec(self, module, out_dict):
        def custom_forward(*inputs):
            output = module(inputs[0], inputs[1],
                            encoder_attn_mask=inputs[2],
                            decoder_padding_mask=inputs[3],
                            layer_state=inputs[4],
                            causal_mask=inputs[5],
                            output_attentions=inputs[6],
                            )
            # output[2] is a Python dictionary: stash it in out_dict
            # instead of returning it through checkpoint
            out_dict[0] = output[2]
            return output[0]
        return custom_forward

    out_dict = {}  # a dict, so out_dict[0] = ... works; an empty list would raise IndexError
    x = checkpoint.checkpoint(
        self.custom_dec(decoder_layer, out_dict),
        x,
        encoder_hidden_states,
        encoder_padding_mask,
        decoder_padding_mask,
        layer_state,
        decoder_causal_mask,
        output_attentions,
    )
    layer_past = out_dict[0]
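For reference, here is a minimal, self-contained sketch of the side-effect pattern. `ToyLayer` and `make_forward` are made-up stand-ins for the decoder layer and the wrapper above, just to show the mechanics end to end:

    import torch
    from torch.utils import checkpoint

    class ToyLayer(torch.nn.Module):
        """Stand-in for the decoder layer: returns a tensor plus a dict."""
        def forward(self, x):
            y = x * 2
            stats = {"doubled_mean": y.mean().item()}  # non-differentiable extra
            return y, stats

    def make_forward(module, out_dict):
        def custom_forward(*inputs):
            y, stats = module(*inputs)
            out_dict[0] = stats  # returned by side effect; harmlessly overwritten
                                 # if checkpoint re-runs this during backward
            return y             # only tensors cross the checkpoint boundary
        return custom_forward

    layer = ToyLayer()
    x = torch.randn(3, requires_grad=True)
    out_dict = {}
    y = checkpoint.checkpoint(make_forward(layer, out_dict), x, use_reentrant=False)
    y.sum().backward()        # gradients still flow through the checkpointed tensor
    layer_past = out_dict[0]  # the dict is recovered from the container

Note that `use_reentrant=False` requires a reasonably recent PyTorch; with the older reentrant implementation the same side-effect pattern applies, you just drop that argument.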