Overwriting Intermediate Variables and Memory Usage

Does overwriting the value of existing variables free memory for possible garbage collection?

Under torch.no_grad() I assume the tensors' memory is free for collection, since no computation graph is recorded. But presumably, when gradients are being tracked, each intermediate tensor is kept alive by the computation graph for the backward pass, regardless of whether you overwrite your references to it?
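The first part of that assumption is easy to check: under no_grad, the result carries no grad_fn, so nothing besides your own references pins the intermediate. A minimal sketch:

```python
import torch

x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    y = x.exp()

# No graph was recorded, so overwriting y really does make
# the old tensor collectible.
print(y.grad_fn)        # None
print(y.requires_grad)  # False
```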

For example, where x is a tensor requiring gradients and the {f_i} are differentiable operations, we can keep references to the intermediate values,

x = f(x)
x1 = f1(x)
x2 = f2(x1)

or we can overwrite the intermediate values.

x = f(x)
x = f1(x)
x = f2(x)

Overwriting intermediates is generally preferable, because not every intermediate is saved for backward, so dropping your references lets those tensors be freed. E.g. in x.sin().exp(), the result of sin is not saved: sin's backward needs its input, and exp's backward saves its own output. Also consider cases where the op is linear (its backward needs neither input nor output), or where activation checkpointing is used.
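The x.sin().exp() case can be demonstrated with a weakref: once the reference to the sin result is dropped, it is collected immediately, and backward still succeeds because the graph never needed it.

```python
import weakref

import torch

x = torch.randn(3, requires_grad=True)
t = x.sin()           # sin's backward needs the *input* x, not this result
ref = weakref.ref(t)
out = t.exp().sum()   # exp's backward saves its own output, not t
del t                 # drop the only reference to the sin result

print(ref() is None)  # True: nothing in the graph kept it alive
out.backward()        # backward still works
```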


@soulitzer @lerner just for fun, I tried checking the process memory before and after sharing the parameters. It did not change — maybe because the freed tensors are so small, or because the interpreter holds on to the memory rather than returning it to the OS.

import gc

import psutil
import torch
from torch import nn


class SampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = nn.Linear(10, 5)
        self.model2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.model1(x)


model = SampleModel()
with torch.no_grad():
    model.to('cpu')

    # Object ids of the four parameters before sharing.
    print(id(model.model1.weight), id(model.model1.bias))
    print(id(model.model2.weight), id(model.model2.bias))

    process = psutil.Process()
    print("Memory used {}".format(process.memory_info()))

    # Point model2's parameters at model1's; the old model2 tensors
    # lose their last references here.
    model.model2.weight = model.model1.weight
    model.model2.bias = model.model1.bias

    # model2 now reports the same object ids as model1.
    print(id(model.model1.weight), id(model.model1.bias))
    print(id(model.model2.weight), id(model.model2.bias))

    gc.collect()
    print("Memory used {}".format(process.memory_info()))