Pass None to Gradient Checkpointing

You can use torch.zeros(0) to “encode” Nones. Or try the code further below: it was copy-pasted from checkpoint.py; I removed the preserve_rng_state functionality (not needed for my use case) and added handling of Python arguments (scalars at least).

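A minimal sketch of the torch.zeros(0) idea (the helpers wrap_optional / unwrap_optional and run_block are made up for illustration): replace None with an empty tensor before calling checkpoint, and decode it back to None inside the checkpointed function.

import torch
from torch.utils.checkpoint import checkpoint

NONE_SENTINEL = torch.zeros(0)  # empty tensor standing in for None (illustrative name)

def wrap_optional(arg):
	# Replace None with the sentinel so checkpoint() only ever sees tensors.
	return NONE_SENTINEL if arg is None else arg

def unwrap_optional(arg):
	# Treat an empty tensor as "no value".
	return None if (torch.is_tensor(arg) and arg.numel() == 0) else arg

def run_block(x, mask):
	mask = unwrap_optional(mask)
	return x * mask if mask is not None else x

x = torch.randn(4, 8, requires_grad=True)
out = checkpoint(run_block, x, wrap_optional(None))
out.sum().backward()

The adapted checkpoint code:
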
import torch
from torch.utils.checkpoint import detach_variable

class CheckpointXFunction(torch.autograd.Function):

	@staticmethod
	def forward(ctx, run_function, *args):
		ctx.run_function = run_function
		
		# Split tensor arguments (saved for backward) from Python arguments
		# (None, scalars, ...), remembering the original position of each Python argument.
		te_args = []
		py_args = []
		for iarg, arg in enumerate(args):
			if torch.is_tensor(arg):
				te_args.append(arg)
			else:
				py_args.append((iarg, arg))
		
		ctx.save_for_backward(*te_args)
		ctx.py_args = py_args
		with torch.no_grad():
			outputs = run_function(*args)
		return outputs

	@staticmethod
	def backward(ctx, *args):
		if not torch.autograd._is_checkpoint_valid():
			raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
		inputs = detach_variable(ctx.saved_tensors)
		if ctx.py_args:
			# Re-insert the non-tensor arguments at their original positions.
			inputs = list(inputs)
			for i, arg in ctx.py_args:
				inputs.insert(i, arg)
			
		with torch.enable_grad():
			outputs = ctx.run_function(*inputs)

		if isinstance(outputs, torch.Tensor):
			outputs = (outputs,)
		
		# Only backpropagate through outputs that are tensors and require grad,
		# paired with their incoming gradients.
		flt_outputs = []
		flt_args = []
		for ioutput, output in enumerate(outputs):
			if torch.is_tensor(output) and output.requires_grad:
				flt_outputs.append(output)
				flt_args.append(args[ioutput])

		if flt_outputs:
			torch.autograd.backward(flt_outputs, flt_args)
		# Non-tensor inputs get None gradients; the leading None is for run_function.
		grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
		return (None,) + grads

def checkpointx(function, *args):
	return CheckpointXFunction.apply(function, *args)
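
For example, a hypothetical usage sketch (block_fn, the mask placeholder and the scale value are made up): a None and a Python scalar are passed straight through, and their gradient slots come back as None.

def block_fn(x, mask, scale):
	out = x * scale
	if mask is not None:
		out = out + mask
	return torch.relu(out)

x = torch.randn(4, 8, requires_grad=True)
out = checkpointx(block_fn, x, None, 0.5)  # None and 0.5 are handled as Python arguments
out.sum().backward()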