I have a simple module that, given a 2D tensor, multiplies it element-wise by another 2D tensor M fixed in advance, then computes a weighted sum. I use a loop to apply the weights to the terms of the sum. There is probably a better way to do this, but it is only a simple example of my more complicated problem, which really requires this kind of loop.
import torch
import torch.nn.functional as F

class NetSimple(torch.nn.Module):
    def __init__(self, M):
        super(NetSimple, self).__init__()
        self.I = 30
        with torch.no_grad():
            # M is fixed in advance, registered here as a parameter
            self.M = torch.nn.Parameter(M)
        self.weight = torch.nn.Parameter(torch.randn(self.I))

    def forward(self, x):
        # x.shape[1] should be 1 (grayscale images)
        prod = torch.zeros(x.shape[0], 1, x.shape[2], x.shape[3]).type_as(x)
        for ii in range(self.I):
            prod += self.weight[ii] * (self.M * x)
        x = F.relu(prod)
        return x

N = 1024  # input size
M = torch.randn(N, N)
model = NetSimple(M)
model = model.cuda()
inp = torch.randn(10, 1, N, N)
inp = inp.cuda()
out = model(inp)
Here, my problem is that the required memory is very large. It seems that a new intermediate the size of M*x is kept for each iteration of the loop, and the forward pass takes about 1.5 GB of GPU memory. That order of magnitude fits: one intermediate of x's size is 10×1×1024×1024 floats ≈ 42 MB, and 30 of them already account for ≈ 1.26 GB.
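For reference, this is how I checked the peak usage (a minimal sketch; it assumes model and inp from the snippet above are already on the GPU):

    # measure peak GPU memory for one forward pass
    torch.cuda.reset_peak_memory_stats()
    out = model(inp)
    print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")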
My question is: is there a memory-efficient way to multiply a tensor by the same fixed tensor several times, without the required memory scaling with the number of multiplications performed?
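To make the question concrete, here is the kind of rewrite I would hope for. In this simplified case the shared product can be hoisted out of the loop, so autograd saves a single intermediate of x's size regardless of self.I; my real problem does not allow exactly this, but it shows the memory behaviour I am after (a sketch of an alternative forward for the NetSimple above):

    def forward(self, x):
        # sketch only: compute the shared product once and reuse it,
        # so autograd saves one intermediate instead of self.I copies
        base = self.M * x
        prod = torch.zeros_like(base)
        for ii in range(self.I):
            prod += self.weight[ii] * base
        return F.relu(prod)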
Thank you in advance.