Multiplication by a fixed tensor in a loop without memory issues

Hello,

I have a simple module that, given a 2D tensor, multiplies it element-wise by another 2D tensor M fixed in advance, then computes a weighted sum. I use a loop to multiply the weights by the terms of the sum. There is probably a better way to do this, but it is only a simplified example of my more complicated problem, which really does require this kind of loop.

import torch
import torch.nn.functional as F

class NetSimple(torch.nn.Module):
    def __init__(self, M):
        super(NetSimple, self).__init__()
        self.I = 30
        # M is fixed in advance, so it should not be updated during training
        self.M = torch.nn.Parameter(M, requires_grad=False)
        self.weight = torch.nn.Parameter(torch.randn(self.I))

    def forward(self, x):
        # x.shape[1] should be 1, grayscale images.
        prod = torch.zeros(x.shape[0], 1, x.shape[2], x.shape[3]).type_as(x)
        for ii in range(self.I):
            prod += self.weight[ii] * (self.M * x)
        x = F.relu(prod)
        return x

N = 1024  # Input size
M = torch.randn(N, N)
model = NetSimple(M)
model = model.cuda()

inp = torch.randn(10,1,N,N)
inp = inp.cuda()
out = model(inp)

Here, my problem is that the memory required is very large. It seems that a new copy of the product M*x is made and kept for each iteration of the loop. It takes about 1.5 GB of GPU memory.
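
(For reference, the peak allocation can be checked with torch.cuda.max_memory_allocated; something like the snippet below, run right after building the model and input above, should roughly reproduce that figure.)

torch.cuda.reset_peak_memory_stats()
out = model(inp)
print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")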

My question is: is there a memory-efficient way to multiply a tensor several times by the same fixed tensor, without the required memory scaling with the number of multiplications?

Thank you in advance.

This creates a rather large computational graph indeed.
self.M*x, though, could be computed a single time and then reused in the loop; that will reduce it quite a bit!
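
Something along these lines, for example (just a sketch of the change to your forward, reusing your existing self.M, self.weight and self.I):

def forward(self, x):
    mx = self.M * x                       # computed a single time, one node in the graph
    prod = torch.zeros_like(mx)
    for ii in range(self.I):
        prod += self.weight[ii] * mx      # reuses the same intermediate every iteration
    return F.relu(prod)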

Also, even though your use case is more complex, remember that using .expand() and then doing one single multiplication (possibly followed by a sum) will be much more efficient in terms of memory and runtime. expand does not use any extra memory, and the computational graph will be substantially smaller, so the memory required to store the intermediate results will be reduced as well.
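
For the toy example above, that pattern could look something like this (a rough sketch; in this particular case the whole sum even collapses to self.weight.sum() * (self.M * x), but the expanded form is what generalizes to per-index terms):

def forward(self, x):
    mx = self.M * x                                         # (B, 1, N, N), computed once
    # add an "I" dimension as a view; expand allocates no extra memory for mx
    mx_e = mx.unsqueeze(0).expand(self.I, -1, -1, -1, -1)   # (I, B, 1, N, N)
    w = self.weight.view(self.I, 1, 1, 1, 1)
    prod = (w * mx_e).sum(dim=0)                            # one multiplication + one sum
    return F.relu(prod)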

Thanks a lot for your answer.
Computing things once before the loop helps a lot with memory.
I didn't want to add a temporary variable because of memory concerns, but that was actually worse, since the graph had to keep a copy of the intermediate result for every iteration in order to compute backpropagation.