Debug save_for_backward memory usage

Context: a DenseNet-like model for speech recognition, JasperNet. I have a custom leaky_relu activation function that calls save_for_backward on tensors that are already kept for backward by other parts of the model anyway (specifically, by my custom BatchNorm-derived modules). However, doing this increases max_memory_reserved considerably, as if it kept additional copies of the tensors or forgot to decrease their refcount after the activation backward is done.

How can I debug the memory usage of the autograd engine? Can I trace the refcounts of tensors saved for backward? Is it possible to dump the tensors kept in memory together with their refcounts?
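
The closest I have found from Python so far is walking the garbage collector. A minimal sketch (the helper name is mine; it only sees tensors reachable from Python, not ones held solely by C++ autograd nodes, and sys.getrefcount reports only the Python-side refcount):

    import gc
    import sys
    import torch

    def dump_live_cuda_tensors():
        # list CUDA tensors reachable from the Python GC with their size and Python refcount
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) and obj.is_cuda:
                    print(type(obj).__name__, tuple(obj.shape), obj.dtype,
                          '{:.2f} mb'.format(obj.element_size() * obj.nelement() / 1e6),
                          'refcount:', sys.getrefcount(obj))
            except Exception:
                pass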

I tried hard to come up with a small repro, but unfortunately I could not, so here is a big repro that uses my full codebase:

    git clone https://github.com/vadimkantorov/convasr # unfortunately requires apex (though not used) and librosa
    cd convasr
    CUDA_VISIBLE_DEVICES=0 python3 benchmark.py --backward --model JasperNetBigInplace
    # load+fwd 185.53 msec | bwd 1035.61 msec | cudamem 5645.53 mb

    CUDA_VISIBLE_DEVICES=0 python3 benchmark.py --backward --model JasperNetBigInplaceBug
    # load+fwd 186.24 msec | bwd 1029.35 msec | cudamem 3535.80 mb

The buggy model does not save the residual branches for backward (and thus does not undo adding them); this is controlled here: https://github.com/vadimkantorov/convasr/blob/0d0141f98db650c39723e09fc1b0f183d2ddd9ea/models.py#L225

@ngimel Maybe you have an idea how to debug such memory increases? I'm not sure whether it creates copies or fails to decrease the refcount when it should.
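
One check I can do from Python (a small standalone sketch; the Probe function name is made up): compare data pointers of a saved tensor inside backward. save_for_backward keeps a reference to the same storage rather than a copy, so this should print True:

    import torch

    class Probe(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            ctx.input_ptr = x.data_ptr()
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            # save_for_backward keeps a reference to the same storage, not a copy
            print('same storage:', x.data_ptr() == ctx.input_ptr)
            return grad_output

    x = torch.randn(4, requires_grad = True)
    Probe.apply(x).sum().backward()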

save_for_backward saves the current tensor plus its "inplace version counter" separately, and at backward compares the tensor's version counter against the saved one (throwing the infamous "one of the variables needed for gradient computation has been modified by an inplace operation" error if they differ).
The thing to probably avoid is a circular dependency at the C++ shared-pointer level, as that would produce memory leaks, but I'm not sure why that should happen here.
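
For illustration, a minimal standalone example of that check firing (exp saves its output for backward, so modifying the output in place afterwards trips the version counter):

    import torch

    x = torch.randn(3, requires_grad = True)
    y = x.exp()   # exp saves its output y for the backward pass
    z = y.sum()
    y.add_(1)     # bumps y's version counter after it was saved
    z.backward()  # RuntimeError: one of the variables needed for gradient computation
                  # has been modified by an inplace operation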

To sidestep version tracking, I do my computations on .data directly and skip mark_dirty in some places (invertible computations are not very well supported for now…)

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class InvertibleResidualInplaceFunction(torch.autograd.Function):
        # nested inside ResidualActivation in the full model code
        bug = False

        @staticmethod
        def forward(ctx, nonlinearity, x, *residual):
            ctx.nonlinearity = nonlinearity
            assert ctx.nonlinearity and ctx.nonlinearity[0] in ['leaky_relu']

            # operate on x.data directly to bypass autograd's version counter
            y = x.data
            for r in residual:
                y += r
            y = getattr(F, ctx.nonlinearity[0])(y, *ctx.nonlinearity[1:], inplace = True)
            ctx.num_residual = len(residual)

            if ResidualActivation.InvertibleResidualInplaceFunction.bug:
                residual = []

            #ctx.mark_dirty(x)
            ctx.save_for_backward(x, *residual)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            # x holds the post-activation output; undo the activation and the residual additions in place
            x, *residual = ctx.saved_tensors
            y = x.data
            mask = torch.ones_like(grad_output).masked_fill_(x < 0, ctx.nonlinearity[1])
            grad_output *= mask
            y /= mask
            for r in residual:
                y -= r
            return (None, ) + (grad_output,) * (1 + ctx.num_residual)

    class InplaceBatchNorm1d(nn.BatchNorm1d):
        def forward(self, input):
            return InplaceBatchNorm1d.Function.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.momentum, self.training)

        class Function(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, weight, bias, running_mean, running_var, eps, momentum, training):
                mean, var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum) if training else (running_mean, running_var)
                invstd = (var + eps).rsqrt_()
                # normalize in place: the output overwrites the input buffer
                output = torch.batch_norm_elemt(input, weight, bias, mean, invstd, 0, out = input)
                ctx.training = training
                ctx.save_for_backward(input, weight, bias, mean, invstd)
                ctx.mark_dirty(input)
                return input

            @staticmethod
            def backward(ctx, grad_output):
                assert ctx.training
                saved_output, weight, bias, mean, invstd = ctx.saved_tensors
                # invert the normalization in place to recover the batch norm input
                saved_input = torch.batch_norm_elemt(saved_output, invstd.reciprocal(), mean, bias, weight.reciprocal(), 0, out = saved_output)
                sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(grad_output, saved_input, mean, invstd, weight, *ctx.needs_input_grad[:3])
                divisor = saved_input.numel() // saved_input.size(1)
                mean_dy = sum_dy.div_(divisor)
                mean_dy_xmu = sum_dy_xmu.div_(divisor)
                grad_input = torch.batch_norm_backward_elemt(grad_output, saved_input, mean, invstd, weight, mean_dy, mean_dy_xmu)
                return grad_input, grad_weight, grad_bias, None, None, None, None, None

    # later on, in the residual block's forward
    residual = [...]
    x = ...
    residual_inputs = [bn(conv(r)) for conv, bn, r in zip(self.conv_residual, self.bn_residual, residual)]
    x = self.activation(bn(conv(x)), residual = residual_inputs)
    # self.activation applies the activation function above, and each bn is an instance of the inplace batch norm above

Also, even when not saving for backward in ResidualActivation, max_memory_reserved does not decrease, even though BatchNorm is done in place and the activation is in place as well…
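
For reference, a minimal way to check the peak reserved memory around a forward/backward pass (a sketch, not the actual benchmark.py code; it assumes a PyTorch version that has these caching allocator counters):

    import torch

    torch.cuda.reset_peak_memory_stats()
    # ... run forward + backward here ...
    print('cudamem {:.2f} mb'.format(torch.cuda.max_memory_reserved() / 1e6))
    print(torch.cuda.memory_summary())  # detailed per-pool caching allocator stats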

@tom I found a bug in my benchmarking, so this is no longer relevant :frowning: I still updated the caching allocator issue for the use case of debugging save_for_backward.