Pass None to Gradient Checkpointing

Hello,

I am trying to adapt Google's T5 model from the HuggingFace library to use gradient checkpointing. However, some of the variables passed through the checkpointed function may be `None`, which leads to an error from checkpointing. Is there any solution to this issue?

Code:

Error message:

TypeError: CheckpointFunctionBackward.forward: expected Variable (got NoneType) for return value 1

1 Like

Any help here would be highly appreciated.

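You can use `torch.zeros(0)` to “encode” `None`s. Alternatively, try the code below, which was copied from `checkpoint.py`. I removed the `preserve_rng_state` functionality (it was not needed for my use case) and added handling of non-tensor Python arguments (at least scalars). Note that without RNG-state preservation, random ops such as dropout are not guaranteed to produce the same values when the forward pass is recomputed during backward.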

from torch.utils.checkpoint import detach_variable

class CheckpointXFunction(torch.autograd.Function):
    """Reentrant gradient checkpointing that accepts non-tensor arguments.

    Adapted from the reentrant checkpoint function in
    ``torch.utils.checkpoint``. Tensor arguments are saved with
    ``ctx.save_for_backward``. Every other positional argument (``None``,
    Python scalars, ...) is kept on ``ctx`` together with its position, so the
    original call can be rebuilt during the backward pass.

    ``run_function`` is executed twice:

    * once under ``torch.no_grad()`` in :meth:`forward`, so no graph is kept;
    * once under ``torch.enable_grad()`` in :meth:`backward`, to rebuild the
      graph and backpropagate through it.

    NOTE: unlike ``torch.utils.checkpoint``, this class does not save or
    restore RNG state. Random ops inside ``run_function`` (e.g. dropout) are
    therefore not guaranteed to reproduce their forward-pass values during
    recomputation.
    """

    @staticmethod
    def forward(ctx, run_function, *args):
        """Run ``run_function(*args)`` without recording an autograd graph.

        Args:
            ctx: autograd context used to stash state for :meth:`backward`.
            run_function: callable invoked as ``run_function(*args)``.
            *args: positional arguments; tensors and non-tensors may be mixed.

        Returns:
            Whatever ``run_function`` returns.
        """
        ctx.run_function = run_function

        tensor_args = []
        non_tensor_args = []  # (position, value) pairs, in ascending position
        for position, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_args.append(arg)
            else:
                non_tensor_args.append((position, arg))

        ctx.save_for_backward(*tensor_args)
        ctx.py_args = non_tensor_args
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Recompute the forward pass with grad enabled, then backpropagate.

        Args:
            ctx: context populated by :meth:`forward`.
            *grad_outputs: one incoming gradient per forward output
                (presumably ``None`` for non-tensor outputs).

        Returns:
            ``None`` for ``run_function``, followed by one entry per element
            of ``args``. Each entry is the gradient of a tensor argument, or
            ``None`` for non-tensor arguments and for tensors without a
            computed gradient.

        Raises:
            RuntimeError: if ``torch.autograd._is_checkpoint_valid()`` reports
                that checkpointing cannot be used here (e.g. when invoked via
                ``.grad()`` instead of ``.backward()``).
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")

        # detach_variable returns detached copies that keep the original
        # requires_grad flags, so recomputation builds a fresh graph rooted at
        # these leaves.
        inputs = list(detach_variable(ctx.saved_tensors))
        # Re-insert non-tensor arguments at their original positions. forward
        # recorded them in ascending order, so each insert lands in the
        # right slot.
        for position, arg in ctx.py_args:
            inputs.insert(position, arg)

        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Only tensor outputs that require grad can be backpropagated through.
        # Skip None *and* any other non-tensor output (e.g. a Python int):
        # checking `output is not None` alone fails on `.requires_grad`.
        differentiable_outputs = []
        matching_grads = []
        for output, grad in zip(outputs, grad_outputs):
            if torch.is_tensor(output) and output.requires_grad:
                differentiable_outputs.append(output)
                matching_grads.append(grad)

        if differentiable_outputs:
            torch.autograd.backward(differentiable_outputs, matching_grads)

        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in inputs
        )
        # The leading None corresponds to the run_function argument.
        return (None,) + grads

def checkpointx(function, *args):
    """Call ``function(*args)`` under ``CheckpointXFunction``.

    Args:
        function: callable to checkpoint.
        *args: positional arguments for ``function``; tensors and non-tensor
            values (``None``, scalars, ...) may be mixed.

    Returns:
        The result of ``CheckpointXFunction.apply(function, *args)``.
    """
    apply_args = (function,) + args
    return CheckpointXFunction.apply(*apply_args)
1 Like