Assignment to split tensor causes memory leak

I want to train a network for smoothing a point cloud.
It is applied independently to each point and takes the surrounding points as input.
The function below performs the actual smoothing by splitting the points (x) and the neighborhoods (y) into batches and then calling the model on each batch.

import torch

def do_smoothing(x, y, model, batch_size):
    torch.no_grad()   # Make sure that no gradient memory is allocated
    print('Used Memory: {:3f}'.format(torch.cuda.memory_allocated('cuda') * 1e-9))
    xbatches = list(torch.split(x, batch_size))
    ybatches = list(torch.split(y, batch_size))
    assert len(xbatches) == len(ybatches)
    print('Used Memory: {:3f}'.format(torch.cuda.memory_allocated('cuda') * 1e-9))
    for i in range(len(xbatches)):
        print("Prior batch {} [Memory: {:3f}]".format(i + 1, torch.cuda.memory_allocated('cuda') * 1e-9))
        tmp = model(xbatches[i])   # run the network on the current batch of points
        xbatches[i] += tmp         # in-place update through the view into x
        del tmp
        print("Post batch {} [Memory: {:3f}]".format(i + 1, torch.cuda.memory_allocated('cuda') * 1e-9))
    return x

Unfortunately, the memory consumption is very high, leading to an out-of-memory error. The console output is:

Used Memory: 0.256431
Used Memory: 0.256431
Prior batch 1 [Memory: 0.256431]
Post batch 1 [Memory: 2.036074]
Prior batch 2 [Memory: 2.036074]
 -> Out of memory error

If I remove the line xbatches[i] += tmp, the allocated memory does not change (as expected).
On the other hand, if I also remove the line del tmp, the code once again allocates huge amounts of GPU memory.
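A plausible explanation for these observations (an assumption, not confirmed in this thread) is that gradient tracking is still enabled, so every model output carries an autograd graph that references the intermediate activations of the forward pass. The sketch below uses a hypothetical stand-in model (a small torch.nn.Sequential, not the network from the question) to show that the output has a grad_fn when grad mode is on and none inside a torch.no_grad() block:

```python
import torch

# Hypothetical stand-in for the smoothing network; any module with
# trainable parameters behaves the same way here.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 3),
)
batch = torch.randn(8, 3)  # the input itself does not require grad

# Grad mode on (the default): the output references a backward graph,
# which keeps saved intermediate tensors alive as long as the output lives.
out_tracked = model(batch)
print(out_tracked.requires_grad, out_tracked.grad_fn is not None)  # True True

# Grad mode off: no graph is recorded, so nothing extra is retained.
with torch.no_grad():
    out_untracked = model(batch)
print(out_untracked.requires_grad, out_untracked.grad_fn is None)  # False True
```

If such a tracked output is added in place into a view of x, x presumably becomes attached to that graph, so the saved activations of each processed batch can stay alive for as long as x does.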

I assume that torch.split creates views of the tensor, so the in-place update should not use additional memory.
Did I miss something, or is this unintended behavior?

I am using PyTorch 1.4.0 with CUDA 10.1 on a Windows 10 machine.

Thanks in advance for any tips.

Could you try using torch.no_grad() as a context manager in a with statement:

with torch.no_grad():
    xbatches = ...

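If I’m not mistaken, your current approach (calling torch.no_grad() as a plain statement) only creates the context-manager object and discards it, so it doesn’t change the gradient behavior.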
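The difference can be checked with torch.is_grad_enabled(). This sketch assumes grad mode starts in its default (enabled) state:

```python
import torch

torch.no_grad()  # object is created and immediately discarded
after_bare_call = torch.is_grad_enabled()  # still True

with torch.no_grad():
    inside_block = torch.is_grad_enabled()  # False inside the block

after_block = torch.is_grad_enabled()  # restored to True on exit

print(after_bare_call, inside_block, after_block)  # True False True
```

Only entering the with block switches gradient tracking off, and leaving it restores the previous state.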

Thank you very much, this indeed fixes the problem.
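Putting the suggested fix together, the function from the question could look roughly like this. It is a sketch, not the poster's final code: the memory printouts are dropped, and, as in the question, only the point batches are passed to the model. The toy model and random data in the usage part are hypothetical and exist only to exercise the function on the CPU:

```python
import torch


def do_smoothing(x, y, model, batch_size):
    """Add the model's output to x batch by batch, writing in place.

    y (the neighborhoods) is split alongside x as in the original code,
    but, as there, only the point batches are fed to the model.
    """
    with torch.no_grad():  # no autograd graph is recorded inside this block
        xbatches = torch.split(x, batch_size)
        ybatches = torch.split(y, batch_size)
        assert len(xbatches) == len(ybatches)
        for xb in xbatches:
            xb += model(xb)  # writes through the view into x
    return x


# Hypothetical usage: 10 points in 3-D, neighborhoods of 4 points each.
model = torch.nn.Linear(3, 3)
x = torch.randn(10, 3)
y = torch.randn(10, 4, 3)
with torch.no_grad():
    expected = x + model(x)  # reference result computed on the whole tensor
out = do_smoothing(x, y, model, batch_size=4)
print(out is x, out.requires_grad)  # True False
```

Decorating the function with @torch.no_grad() would be an equivalent way to disable gradient tracking for its whole body.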