How to free the backward pass's input args when using IREE as the compilation backend?

Hello, everyone. I have a somewhat complicated question to discuss; please bear with me while I describe it in detail.
My goal is to build a PyTorch + iree-turbine + IREE system to run model training on my custom backend. You can regard iree-turbine as a bridge that makes the IREE Python API easy to use; the IREE backend acts as both a compiler and a runtime.
I have made the system work, but I have run into a memory leak, which is why I am here.

As far as I know, with torch.compile and the Inductor backend, when running model training, PyTorch saves the forward's outputs for the backward pass, and the backward takes the loss gradient and the forward's outputs as its own inputs. The functions I printed look like below:

  • Inductor forward function (please ignore the debug print statements I added):
def call(args):
    primals_1, primals_2, primals_3, primals_4, primals_5 = args
    args.clear()
    assert_size_stride(primals_1, (128, 128), (128, 1))
    assert_size_stride(primals_2, (128, 128), (128, 1))
    assert_size_stride(primals_3, (128, 128), (128, 1))
    assert_size_stride(primals_4, (128, 128), (128, 1))
    assert_size_stride(primals_5, (128, 128), (128, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf4 = empty_strided_cuda((128, 128), (128, 1), torch.bool)
        buf3 = empty_strided_cuda((128, 128), (128, 1), torch.float32)
        buf1 = empty_strided_cuda((2, ), (1, ), torch.float32)
        # Source Nodes: [add, mean, mul, mul_1, out, out_1, pow_1, sub], Original ATen: [aten.add, aten.mean, aten.mul, aten.pow, aten.relu, aten.sub, aten.threshold_backward]
        stream0 = get_raw_stream(0)
        triton_red_fused_add_mean_mul_pow_relu_sub_threshold_backward_0.run(primals_1, primals_4, primals_2, primals_3, primals_5, buf4, buf3, buf1, 2, 8192, grid=grid(2), stream=stream0)
        import sys  # for refcount debug
        print("Before del primals_1, refcount =", sys.getrefcount(primals_1))
        del primals_1
        print("After del primals_1")
        print("Before del primals_2, refcount =", sys.getrefcount(primals_2))
        del primals_2
        print("After del primals_2")
        print("Before del primals_3, refcount =", sys.getrefcount(primals_3))
        del primals_3
        print("After del primals_3")
        print("Before del primals_5, refcount =", sys.getrefcount(primals_5))
        del primals_5
        print("After del primals_5")
        buf2 = empty_strided_cuda((), (), torch.float32)
        buf5 = buf2; del buf2  # reuse
        # Source Nodes: [mean, pow_1], Original ATen: [aten.mean, aten.pow]
        triton_per_fused_mean_pow_1.run(buf5, buf1, 1, 2, grid=grid(1), stream=stream0)
        del buf1
    return (buf5, primals_4, buf3, buf4, )
  • Inductor backward function:
def call(args):
    primals_4, mul_2, le, tangents_1 = args
    args.clear()
    assert_size_stride(primals_4, (128, 128), (128, 1))
    assert_size_stride(mul_2, (128, 128), (128, 1))
    assert_size_stride(le, (128, 128), (128, 1))
    assert_size_stride(tangents_1, (), ())
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((128, 128), (128, 1), torch.float32)
        buf1 = empty_strided_cuda((128, 128), (128, 1), torch.float32)
        # Source Nodes: [], Original ATen: [aten.div, aten.mul, aten.threshold_backward]
        stream0 = get_raw_stream(0)
        triton_poi_fused_div_mul_threshold_backward_0.run(le, tangents_1, mul_2, primals_4, buf0, buf1, 16384, grid=grid(16384), stream=stream0)
        import sys  # for refcount debug
        print("Before del le, refcount =", sys.getrefcount(le))
        del le
        print("After del le")
        print("Before del mul_2, refcount =", sys.getrefcount(mul_2))
        del mul_2
        print("After del mul_2")
        print("Before del primals_4, refcount =", sys.getrefcount(primals_4))
        del primals_4
        print("After del primals_4")
        print("Before del tangents_1, refcount =", sys.getrefcount(tangents_1))
        del tangents_1
        print("After del tangents_1")
    return (buf1, buf1, buf0, None, None, )

Firstly, I notice that the forward's outputs buf5, buf3, and buf4 are allocated inside this function and are later passed as inputs to the backward. In the backward function, the inputs are all “del”ed at the end. So, except for primals_4, all of the backward's inputs can be freed. I understand that “del” only decreases the reference count rather than directly freeing the tensor memory, and the memory is released only once the last reference is dropped. Is my understanding right?
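To make that first point concrete, here is a minimal pure-Python sketch (the `Tensor` class below is just a stand-in I made up to observe deallocation, not a real torch tensor): CPython frees an object immediately when its reference count reaches zero, so `del` releases memory only if it removes the last reference.

```python
import sys

class Tensor:  # stand-in for a real tensor, just to observe deallocation
    def __del__(self):
        print("storage freed")

t = Tensor()
alias = t                   # a second reference, like a saved-for-backward handle
print(sys.getrefcount(t))   # -> 3 (t, alias, plus getrefcount's own temporary ref)
del t                       # refcount drops, but `alias` still keeps the object alive
del alias                   # last reference gone -> __del__ runs, prints "storage freed"
```

So in the Inductor backward above, each `del` frees the tensor's storage only if no other Python object (the caller's args list, autograd's saved tensors, etc.) still references it.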
Secondly, if I use my own IREE backend, the backward function's call looks like below:

 def __call__(self, *inputs):
        arg_list = VmVariantList(len(inputs))
        ret_list = VmVariantList(
            1
        )  # TODO: Get the number of results from the descriptor.

        # Move inputs to the device and add to arguments.
        self._inputs_to_device(inputs, arg_list)
        # TODO: Append semaphores for async execution.

        # Invoke.
        self.vm_context.invoke(self.entry_function, arg_list, ret_list)
       
        print("---free bw args---\n")
        del arg_list  # drop the VM argument list's references to the buffers
        import sys, gc  # needed for the refcount prints and the forced collection below
        inputs = list(inputs)
        for i in range(len(inputs)):
            print(f"refcount before inputs[{i}] = None:", sys.getrefcount(inputs[i]))
            inputs[i] = None
        del inputs
        gc.collect()
        print("------------------\n")
        

        return self._returns_to_user(ret_list)

The “free bw args” part is the code I added to free the backward's input tensors. I tried to model it after the Inductor backward function's `del` of the inputs, because the IREE runtime doesn't know whether a tensor should be freed or not. For example, primals_4 should not be freed, because it is a forward input and is also a weight of the model.
After adding the “free bw args” part, I can't see any memory being freed or recycled.
So I want to ask: why can't the args of the backward function be freed? I know the freeing is controlled by the reference count of intrusive_ptr, but I am really not familiar with the logic of this memory management.
And what should I do if I really want to free the backward's input tensors (the ones produced by the forward that are safe to delete)?
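One observation that may explain the difference (this is an assumption about calling conventions, not about the IREE API itself): Inductor's `call(args)` receives one mutable list and immediately does `args.clear()`, so the caller gives up its references and the function body holds the only remaining ones, which the final `del`s then release. With `def __call__(self, *inputs)`, the star-args tuple is a fresh local copy, so setting its elements to `None` never touches the caller's references and the refcount cannot reach zero. A minimal sketch of the difference (`Buf`, `call_star`, and `call_list` are illustrative names I made up):

```python
class Buf:  # stand-in for a tensor storage, just to observe when it is freed
    freed = []
    def __init__(self, name):
        self.name = name
    def __del__(self):
        Buf.freed.append(self.name)

def call_star(*inputs):   # like `def __call__(self, *inputs)`
    inputs = None          # only drops the local tuple; the caller still holds refs

def call_list(args):       # the Inductor convention
    a, b = args
    args.clear()           # the caller's list no longer references the tensors
    del a, b               # the last references die here -> storage freed

caller_refs = [Buf("x"), Buf("y")]
call_star(*caller_refs)
print(Buf.freed)           # [] - nothing freed, caller_refs is still alive

args = caller_refs
caller_refs = None         # the caller gives up its own handle too
call_list(args)
print(Buf.freed)           # ['x', 'y'] - freed inside the callee
```

So if this reading is right, freeing inside `__call__` would require the calling layer (the AOT autograd wrapper that holds the saved-for-backward tensors) to hand over ownership in a clearable container, the way Inductor's generated `call` does.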
Hoping for some help, thanks a lot.