Hello, everyone. I have a somewhat complicated question to discuss; please bear with me while I describe it in detail.
My goal is to build a PyTorch + iree-turbine + IREE system that runs model training on my custom backend. You can think of iree-turbine as a bridge that makes the IREE Python API easy to use; the IREE backend works as both a compiler and a runner.
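To make the setup concrete, here is a minimal sketch of how a custom torch.compile backend plugs in. This is not my real backend, just the shape of the integration: my_iree_backend is a placeholder name, and the real one hands the FX graph to iree-turbine for compilation instead of running it eagerly.

import torch

def my_iree_backend(gm: torch.fx.GraphModule, example_inputs):
    # torch.compile hands the backend an FX GraphModule plus example
    # inputs and expects a callable back; the real backend would compile
    # `gm` through iree-turbine/IREE here.
    def compiled(*args):
        return gm(*args)  # eager fallback, stand-in for the IREE runner
    return compiled

model = torch.nn.Linear(128, 128)
opt_model = torch.compile(model, backend=my_iree_backend)
out = opt_model(torch.randn(4, 128))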
I have made the system work, but I have a question about a memory leak, which is why I am here.
As I understand it, with torch.compile and Inductor as the backend, when training a model PyTorch saves the forward (fw) outputs for the backward (bw) pass, and the bw takes the loss gradient and the fw's saved outputs as its own inputs. The functions I printed look like this:
- Inductor fw function (please ignore the added "print"s):
def call(args):
    primals_1, primals_2, primals_3, primals_4, primals_5 = args
    args.clear()
    assert_size_stride(primals_1, (128, 128), (128, 1))
    assert_size_stride(primals_2, (128, 128), (128, 1))
    assert_size_stride(primals_3, (128, 128), (128, 1))
    assert_size_stride(primals_4, (128, 128), (128, 1))
    assert_size_stride(primals_5, (128, 128), (128, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf4 = empty_strided_cuda((128, 128), (128, 1), torch.bool)
        buf3 = empty_strided_cuda((128, 128), (128, 1), torch.float32)
        buf1 = empty_strided_cuda((2, ), (1, ), torch.float32)
        # Source Nodes: [add, mean, mul, mul_1, out, out_1, pow_1, sub], Original ATen: [aten.add, aten.mean, aten.mul, aten.pow, aten.relu, aten.sub, aten.threshold_backward]
        stream0 = get_raw_stream(0)
        triton_red_fused_add_mean_mul_pow_relu_sub_threshold_backward_0.run(primals_1, primals_4, primals_2, primals_3, primals_5, buf4, buf3, buf1, 2, 8192, grid=grid(2), stream=stream0)
        import sys  # for refcount debug
        print("Before del primals_1, refcount =", sys.getrefcount(primals_1))
        del primals_1
        print("After del primals_1")
        print("Before del primals_2, refcount =", sys.getrefcount(primals_2))
        del primals_2
        print("After del primals_2")
        print("Before del primals_3, refcount =", sys.getrefcount(primals_3))
        del primals_3
        print("After del primals_3")
        print("Before del primals_5, refcount =", sys.getrefcount(primals_5))
        del primals_5
        print("After del primals_5")
        buf2 = empty_strided_cuda((), (), torch.float32)
        buf5 = buf2; del buf2  # reuse
        # Source Nodes: [mean, pow_1], Original ATen: [aten.mean, aten.pow]
        triton_per_fused_mean_pow_1.run(buf5, buf1, 1, 2, grid=grid(1), stream=stream0)
        del buf1
    return (buf5, primals_4, buf3, buf4, )
- Inductor bw function:
def call(args):
    primals_4, mul_2, le, tangents_1 = args
    args.clear()
    assert_size_stride(primals_4, (128, 128), (128, 1))
    assert_size_stride(mul_2, (128, 128), (128, 1))
    assert_size_stride(le, (128, 128), (128, 1))
    assert_size_stride(tangents_1, (), ())
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((128, 128), (128, 1), torch.float32)
        buf1 = empty_strided_cuda((128, 128), (128, 1), torch.float32)
        # Source Nodes: [], Original ATen: [aten.div, aten.mul, aten.threshold_backward]
        stream0 = get_raw_stream(0)
        triton_poi_fused_div_mul_threshold_backward_0.run(le, tangents_1, mul_2, primals_4, buf0, buf1, 16384, grid=grid(16384), stream=stream0)
        import sys  # for refcount debug
        print("Before del le, refcount =", sys.getrefcount(le))
        del le
        print("After del le")
        print("Before del mul_2, refcount =", sys.getrefcount(mul_2))
        del mul_2
        print("After del mul_2")
        print("Before del primals_4, refcount =", sys.getrefcount(primals_4))
        del primals_4
        print("After del primals_4")
        print("Before del tangents_1, refcount =", sys.getrefcount(tangents_1))
        del tangents_1
        print("After del tangents_1")
    return (buf1, buf1, buf0, None, None, )
Firstly, notice the outputs of the forward: buf5, buf3, and buf4 are all allocated inside this function, and they are passed as inputs to the bw. In the bw function, the inputs are all "del"ed at the end, so all of the bw inputs except primals_4 end up freed (primals_4 is also a model weight, so other references keep it alive). My understanding is that "del" only decreases the reference count; it does not by itself free the tensor's memory. Is that right?
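To double-check that understanding, here is a standalone sketch (plain PyTorch, nothing IREE-specific) showing that del only drops one Python reference, the storage is returned to the CUDA caching allocator only when the last reference dies, and even then the allocator keeps the block cached for reuse, so nvidia-smi will not shrink:

import sys
import torch

t = torch.empty(128, 128, device="cuda")    # 65536 bytes of storage
alias = t                                   # second reference to the same tensor
print(torch.cuda.memory_allocated())        # includes the 65536 bytes
del t                                       # refcount 2 -> 1: nothing is freed
print(torch.cuda.memory_allocated())        # unchanged, `alias` keeps it alive
del alias                                   # last reference dies: storage released
print(torch.cuda.memory_allocated())        # drops back down
print(torch.cuda.memory_reserved())         # still nonzero: the block stays cached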
Secondly, when I use my own IREE backend, the bw function's call looks like this:
def __call__(self, *inputs):
    arg_list = VmVariantList(len(inputs))
    ret_list = VmVariantList(
        1
    )  # TODO: Get the number of results from the descriptor.
    # Move inputs to the device and add to arguments.
    self._inputs_to_device(inputs, arg_list)
    # TODO: Append semaphores for async execution.
    # Invoke.
    self.vm_context.invoke(self.entry_function, arg_list, ret_list)
    print("---free bw args---\n")
    del arg_list
    inputs = list(inputs)
    for i in range(len(inputs)):
        print(f"refcount before inputs[{i}] = None:", sys.getrefcount(inputs[i]))
        inputs[i] = None
        print(f"refcount after inputs[{i}] = None:", "cannot measure" if inputs[i] is None else sys.getrefcount(inputs[i]))
    del inputs
    import gc
    gc.collect()
    print("------------------\n")
    return self._returns_to_user(ret_list)
The "free bw args" part is the code I added to free the bw's input tensors. I tried to model it after the Inductor bw function and del the inputs, because the IREE runtime doesn't know whether a given tensor should be freed or not. For example, primals_4 must not be freed, because it is a fw input and is also a weight of the model.
After adding the "free bw args" part, I don't see any memory being freed or recycled.
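While staring at the two versions, I noticed one difference that a pure-Python sketch reproduces (no IREE involved; Payload is just a stand-in for a tensor): Inductor's call(args) receives the caller's own list and runs args.clear(), which drops the caller's references too, whereas my __call__(self, *inputs) packs the arguments into a fresh tuple, so setting them to None inside only drops the callee's references while the caller still holds its own:

import sys

class Payload:
    # Stand-in for a tensor; __del__ fires when the last reference dies.
    def __del__(self):
        print("Payload freed")

def callee_star(*inputs):
    # `*inputs` packs the arguments into a new tuple; nulling local
    # references cannot touch the references the caller still holds.
    inputs = list(inputs)
    for i in range(len(inputs)):
        inputs[i] = None

def callee_list(args):
    # Inductor-style: empty the caller's own list, so after this returns
    # the caller no longer references the payload either.
    args.clear()

p = Payload()
print(sys.getrefcount(p))    # 2: `p` plus getrefcount's own argument
callee_star(p)
print(sys.getrefcount(p))    # unchanged; nothing was freed
lst = [p]
del p                        # now only `lst` references the Payload
callee_list(lst)             # prints "Payload freed" immediately
print(lst)                   # []

Note that gc.collect() cannot help in this situation: objects without reference cycles are freed by the reference count alone, so collecting changes nothing while a live reference remains.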
So I want to ask: why can't the args of the bw function be freed? I know the freeing is ultimately controlled by the reference count of the intrusive_ptr, but I'm really not familiar with the memory-management logic.
And what should I do if I really want to free the bw's input tensors (the ones produced by the fw, which should be safe to del)?
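For reference, below is the direction I am considering, as a hedged sketch only: it assumes the caller can be changed to pass a mutable list, the way Inductor's runtime wrapper does, and it reuses my wrapper's existing helpers.

def __call__(self, args: list):
    # Sketch: adopt Inductor's convention of taking ownership of the
    # caller's list and emptying it, so these become the last Python
    # references to the tensors the forward produced.
    arg_list = VmVariantList(len(args))
    self._inputs_to_device(args, arg_list)
    args.clear()  # drop the caller's tensor references, like Inductor's call(args)
    ret_list = VmVariantList(1)
    self.vm_context.invoke(self.entry_function, arg_list, ret_list)
    del arg_list  # release the VM-side references as well
    return self._returns_to_user(ret_list)

Would something like this be the right approach, or does the IREE runtime hold other references that keep the buffers alive?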
Really hoping for some help, thanks a lot.