Pass None to Gradient Checkpointing


I am trying to convert Google's T5 model from the HuggingFace library to use gradient checkpointing. However, some of the arguments may be None, which leads to an error from checkpointing. Are there any solutions to this issue?


Error message:

TypeError: CheckpointFunctionBackward.forward: expected Variable (got NoneType) for return value 1

Any help would be highly appreciated.

You can use `torch.zeros(0)` to “encode” `None`s. Alternatively, try the code below. It was copy-pasted from PyTorch's `torch.utils.checkpoint` (its `CheckpointFunction`). I removed the `preserve_rng_state` functionality (it was not needed for my use case) and added handling of non-tensor Python arguments (scalars and `None`, at least).

import torch
from torch.utils.checkpoint import detach_variable

class CheckpointXFunction(torch.autograd.Function):
    """Gradient checkpointing that tolerates non-tensor arguments such as None.

    ``forward`` runs ``run_function`` under ``torch.no_grad()``, so no graph is
    kept. ``backward`` re-runs it with grad enabled and backpropagates through
    the recomputed graph. Tensor arguments are stored with
    ``ctx.save_for_backward``. Every other argument (None, Python scalars, ...)
    is stored on ``ctx`` with its position, so the original call can be rebuilt
    during backward.

    RNG state is NOT saved or restored between the two runs. Random ops
    (e.g. dropout) are therefore not guaranteed to replay identically.
    """

    @staticmethod
    def forward(ctx, run_function, *args):
        """Run ``run_function(*args)`` without recording an autograd graph.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            run_function: Callable to checkpoint; it is called again in
                ``backward``.
            *args: Positional arguments for ``run_function``. Tensors and
                non-tensor values may be mixed freely.

        Returns:
            Whatever ``run_function`` returns.
        """
        ctx.run_function = run_function
        tensor_args = []
        py_args = []  # (position, value) pairs for the non-tensor arguments
        for iarg, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_args.append(arg)
            else:
                py_args.append((iarg, arg))
        ctx.save_for_backward(*tensor_args)
        ctx.py_args = py_args
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Recompute the forward pass with grad enabled and backpropagate.

        Args:
            ctx: Autograd context filled in by ``forward``.
            *grad_outputs: Incoming gradients, one per output of ``forward``.

        Returns:
            A tuple with ``None`` for ``run_function``, followed by one entry per
            forward argument. A tensor argument gets the ``.grad`` of its
            recomputed copy, which is None if no gradient reached it. A
            non-tensor argument gets None.

        Raises:
            RuntimeError: if invoked through ``torch.autograd.grad`` rather than
                ``.backward()``.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = list(detach_variable(ctx.saved_tensors))
        # py_args was collected in ascending position order. Inserting
        # front-to-back therefore restores the original argument order.
        for iarg, arg in ctx.py_args:
            inputs.insert(iarg, arg)
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        # Only tensor outputs that require grad take part in backprop.
        # None or other non-tensor outputs are skipped instead of touching
        # `.requires_grad` on them.
        flt_outputs = []
        flt_grads = []
        for ioutput, output in enumerate(outputs):
            if torch.is_tensor(output) and output.requires_grad:
                flt_outputs.append(output)
                flt_grads.append(grad_outputs[ioutput])
        if flt_outputs:
            torch.autograd.backward(flt_outputs, flt_grads)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
        return (None,) + grads

def checkpointx(function, *args):
    """Call ``function(*args)`` through ``CheckpointXFunction`` (gradient checkpointing).

    ``args`` may mix tensors with non-tensor values such as ``None`` or Python
    scalars. Returns whatever ``function`` returns.
    """
    apply_checkpoint = CheckpointXFunction.apply
    return apply_checkpoint(function, *args)
