Redundant context saved in the forward pass when some input tensors do not require grad

In the following example the memory consumption is very high, which indicates that the forward pass saves every tensor used in the multiplications, even though only one of the tensors (b) requires grad.

import torch
import memory_profiler

gs = 256
b = torch.randn((gs, gs), dtype=torch.float32, requires_grad=True)  # only b requires grad
a = torch.randn((gs, gs), dtype=torch.float32)  # a is a constant (requires_grad=False)
b_orig = b

start_mem_usage = memory_profiler.memory_usage(-1, 0.1)

for i in range(1000):
    b = (b + 1).mm(a)

l = torch.sum(b)
end_mem_usage = memory_profiler.memory_usage(-1, 0.1)
print("mem usage", end_mem_usage[0] - start_mem_usage[0])

This code needs ~285 MB, which means the forward pass is saving 1000 tensors of 256x256 floats. Since ‘a’ doesn’t require grad, I’m having a hard time understanding why the context needs anything beyond ‘a’ itself, which never changes.
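As a rough check of that estimate, assuming the forward pass keeps one 256x256 float32 matrix per iteration (the output of b+1, which mm would only need for the gradient with respect to ‘a’), the saved tensors alone come to 250 MiB, i.e. most of the measured ~285 MB. The snippet just does the multiplication; the variable names are illustrative.

```python
# Rough size of the tensors kept alive if the forward pass saves
# one 256x256 float32 matrix per loop iteration.
n_iters = 1000
gs = 256
bytes_per_float32 = 4

saved_bytes = n_iters * gs * gs * bytes_per_float32
print(saved_bytes)                 # 262144000 bytes
print(saved_bytes / 2**20, "MiB")  # 250.0 MiB
```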

So I went ahead and implemented my own matmul with a constant multiplier, and this time it consumes almost no memory. My question is: why doesn’t torch figure this out by itself from the requires_grad attribute?

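class myMatMulConstant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, const_mat):
        # Only the constant matrix is needed for the gradient w.r.t. input
        ctx.save_for_backward(const_mat)
        return input.mm(const_mat)

    @staticmethod
    def backward(ctx, grad_output):
        const_mat = ctx.saved_tensors[0]
        # d(input @ C)/d(input) applied to grad_output is grad_output @ C^T
        grad_input = grad_output.mm(const_mat.t())
        return grad_input, None  # const_mat gets no gradient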

my_mm_const = myMatMulConstant.apply
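Before relying on a hand-written backward, it may be worth checking it against autograd's own mm. The sketch below repeats a complete version of myMatMulConstant (so it runs on its own), runs torch.autograd.gradcheck on it in float64, and compares the gradient of a shortened version of the loop against one built with torch.mm. The sizes (8x8), the 5 iterations and the helper `chain` are arbitrary, hypothetical choices for the test.

```python
import torch


class myMatMulConstant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, const_mat):
        # Only the constant matrix is needed for the gradient w.r.t. input
        ctx.save_for_backward(const_mat)
        return input.mm(const_mat)

    @staticmethod
    def backward(ctx, grad_output):
        const_mat = ctx.saved_tensors[0]
        grad_input = grad_output.mm(const_mat.t())
        return grad_input, None  # const_mat gets no gradient


my_mm_const = myMatMulConstant.apply

torch.manual_seed(0)
# gradcheck compares analytical gradients with finite differences,
# so double precision is used here
x = torch.randn(8, 8, dtype=torch.float64, requires_grad=True)
c = torch.randn(8, 8, dtype=torch.float64)  # constant, no grad

# 1) finite-difference check of the custom backward (only x is checked)
ok = torch.autograd.gradcheck(lambda inp: my_mm_const(inp, c), (x,))


# 2) same chain as the loop in the post, built with either matmul
def chain(mm, x, c, n=5):
    b = x
    for _ in range(n):
        b = mm(b + 1, c)
    return b.sum()


g_custom, = torch.autograd.grad(chain(my_mm_const, x, c), x)
g_builtin, = torch.autograd.grad(chain(torch.mm, x, c), x)
print(ok, torch.allclose(g_custom, g_builtin))  # True True
```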

for i in range(1000):
    b = my_mm_const(b+1,a)

The above implementation doesn’t keep 1000 tensors in memory, which makes sense.
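Whether the built-in mm saves the operand that is only needed for the gradient of a non-grad input may depend on the PyTorch version. One way to see what an installed version keeps is to count the tensors passed to the pack hook of torch.autograd.graph.saved_tensors_hooks (available in recent releases). This is a diagnostic sketch; `count_saved`, `a_const` and `a_grad` are hypothetical names. If the first count is 1, mm saves only the constant matrix; if it is 2, it keeps both operands, as the measurement in the post suggests.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def count_saved(fn):
    """Run fn() and return how many tensors autograd packs for backward."""
    saved = []

    def pack(t):
        saved.append(t.shape)  # record only the shape; keep the tensor as-is
        return t

    def unpack(t):
        return t

    with saved_tensors_hooks(pack, unpack):
        fn()
    return len(saved)


x = torch.randn(256, 256, requires_grad=True)
a_const = torch.randn(256, 256)                     # does not require grad
a_grad = torch.randn(256, 256, requires_grad=True)  # requires grad

only_x = count_saved(lambda: x.mm(a_const))
both = count_saved(lambda: x.mm(a_grad))
print("mm, only x requires grad:", only_x)
print("mm, both require grad:", both)
```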