You can use torch.zeros(0) to "encode" Nones (a minimal sketch of that trick follows the listing below). Or try the code below: it was copy-pasted from checkpoint.py, with the preserve_rng_state functionality removed (I did not need it for my use case) and handling of plain Python arguments (scalars, at least) added.
import torch
from torch.utils.checkpoint import detach_variable


class CheckpointXFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        # Split the inputs: tensors go through save_for_backward, everything
        # else (scalars etc.) is stashed on ctx together with its position.
        te_args = []
        py_args = []
        for iarg, arg in enumerate(args):
            if torch.is_tensor(arg):
                te_args.append(arg)
            else:
                py_args.append((iarg, arg))
        ctx.save_for_backward(*te_args)
        ctx.py_args = py_args
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        # Detach the saved tensors and re-insert the non-tensor arguments
        # at their original positions before re-running the function.
        inputs = detach_variable(ctx.saved_tensors)
        if ctx.py_args:
            inputs = list(inputs)
            for i, arg in ctx.py_args:
                inputs.insert(i, arg)
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        # Backpropagate only through outputs that actually require grad.
        flt_outputs = []
        flt_args = []
        for ioutput, output in enumerate(outputs):
            if output is not None and output.requires_grad:
                flt_outputs.append(output)
                flt_args.append(args[ioutput])
        if flt_outputs:
            torch.autograd.backward(flt_outputs, flt_args)
        # One gradient slot per forward input; non-tensor inputs get None.
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
        return (None,) + grads


def checkpointx(function, *args):
    return CheckpointXFunction.apply(function, *args)
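For reference, a minimal usage sketch (the function run and its arguments are made up for illustration): you pass the scalar straight through, and backward returns None for the non-tensor slot.

def run(x, scale):
    return (x * scale).relu()

x = torch.randn(8, 8, requires_grad=True)
out = checkpointx(run, x, 0.5)  # 0.5 is a plain Python scalar, not a tensor
out.sum().backward()
print(x.grad.shape)  # gradients reach x; the scalar argument gets None internally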
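As for the torch.zeros(0) option: the idea is to pass an empty tensor in place of None and decode it back inside the checkpointed function. A minimal sketch using the stock torch.utils.checkpoint.checkpoint; treating numel() == 0 as "no value" is just the convention I assume here, use whatever sentinel fits your model.

from torch.utils.checkpoint import checkpoint

def run2(x, mask):
    if mask.numel() == 0:  # empty tensor stands in for None (assumed convention)
        mask = None
    out = x * 2
    if mask is not None:
        out = out + mask
    return out

x = torch.randn(4, 4, requires_grad=True)
out = checkpoint(run2, x, torch.zeros(0))  # torch.zeros(0) "encodes" None
out.sum().backward()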