Hello,
I have a simple module that, given a 2D tensor, multiplies it element-wise by another fixed 2D tensor M and then computes a weighted sum. I use a loop to apply the weights to the terms of the sum. There is probably a better way to do this, but this is only a simple example of my more complicated problem, which really requires this kind of loop.
```python
import torch
import torch.nn.functional as F


class NetSimple(torch.nn.Module):
    def __init__(self, M):
        super(NetSimple, self).__init__()
        self.I = 30
        # Note: torch.no_grad() does not freeze M; the Parameter still has requires_grad=True
        with torch.no_grad():
            self.M = torch.nn.Parameter(M)
        self.weight = torch.nn.Parameter(torch.randn(self.I))

    def forward(self, x):
        # x.shape[1] should be 1, grayscale images.
        prod = torch.zeros(x.shape[0], 1, x.shape[2], x.shape[3]).type_as(x)
        for ii in range(self.I):
            prod += self.weight[ii] * (self.M * x)
        x = F.relu(prod)
        return x


N = 1024  # Input size
M = torch.randn(N, N)
model = NetSimple(M)
model = model.cuda()
inp = torch.randn(10, 1, N, N)
inp = inp.cuda()
out = model(inp)
```
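For this simplified example only, the loop can be checked against a closed form: since `self.M*x` does not depend on `ii`, the sum of `weight[ii]*(M*x)` over `ii` equals `weight.sum()*(M*x)`. The sketch below compares the two on small CPU tensors; the sizes and the use of float64 (to keep rounding differences negligible) are assumptions for illustration, not taken from the post.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
I, B, N = 30, 2, 8  # illustrative sizes, much smaller than in the post
# float64 keeps rounding differences between the two summation orders negligible
M = torch.randn(N, N, dtype=torch.float64)
weight = torch.randn(I, dtype=torch.float64)
x = torch.randn(B, 1, N, N, dtype=torch.float64)

# Loop version, mirroring NetSimple.forward
prod = torch.zeros(B, 1, N, N, dtype=torch.float64)
for ii in range(I):
    prod += weight[ii] * (M * x)
out_loop = F.relu(prod)

# Closed form: M*x does not depend on ii, so the weights can be summed first
out_closed = F.relu(weight.sum() * (M * x))

print(torch.allclose(out_loop, out_closed))  # True
```

This shortcut only applies while the factor multiplied inside the loop is the same for every `ii`; it says nothing about the more complicated case the loop stands in for.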
In this example, my problem is that the memory usage is very high. It seems that a new copy of M is made for each iteration of the loop. It takes about 1.5 GB of GPU memory.
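One way to check where the memory goes, assuming the growth comes from autograd rather than from copies of `M`: because `weight` requires gradients, every `weight[ii]*(self.M*x)` must keep its `self.M*x` operand for the backward pass, and that operand is a fresh tensor the size of `x` in each iteration. For a 10×1×1024×1024 float32 input this is 40 MiB per iteration, or 1200 MiB over 30 iterations, which would be in line with the ~1.5 GB observed. The sketch below counts the bytes autograd saves using `torch.autograd.graph.saved_tensors_hooks` (available in recent PyTorch versions) for three variants on small CPU tensors: the original loop, the same loop with `M*x` computed once before it, and the original loop under `torch.no_grad()`. The helpers `forward_loop`, `forward_hoisted` and `saved_bytes` are hypothetical names for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
I, B, N = 30, 2, 64  # illustrative sizes, much smaller than in the post
M = torch.randn(N, N, requires_grad=True)  # trainable, as in the post
weight = torch.randn(I, requires_grad=True)
x = torch.randn(B, 1, N, N)

def forward_loop(x):
    # As in the post: M*x is recomputed, and saved by autograd, on every iteration
    prod = torch.zeros_like(x)
    for ii in range(I):
        prod = prod + weight[ii] * (M * x)
    return F.relu(prod)

def forward_hoisted(x):
    # Same loop, but M*x is computed once, so every iteration saves the same tensor
    Mx = M * x
    prod = torch.zeros_like(x)
    for ii in range(I):
        prod = prod + weight[ii] * Mx
    return F.relu(prod)

def saved_bytes(fn, x):
    """Run fn(x); return its output and the bytes of distinct tensors saved for backward."""
    seen = {}  # data_ptr -> size in bytes, so a tensor saved several times counts once

    def pack(t):
        seen[t.data_ptr()] = t.numel() * t.element_size()
        return t

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = fn(x)
    return out, sum(seen.values())

out_loop, loop_b = saved_bytes(forward_loop, x)
out_hoist, hoist_b = saved_bytes(forward_hoisted, x)
with torch.no_grad():
    out_ng, ng_b = saved_bytes(forward_loop, x)  # no graph, so nothing is saved

print(loop_b, hoist_b, ng_b)
print(torch.allclose(out_loop, out_hoist), torch.allclose(out_loop, out_ng))
```

In this toy setting the hoisted loop keeps one `M*x` instead of one per iteration, and the `no_grad` run keeps none. Hoisting only helps if the real problem multiplies by the same tensor in every iteration; for pure inference, running the model under `torch.no_grad()` avoids storing these intermediates either way.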
My question is: is there a memory-efficient way to multiply a tensor several times by the same fixed tensor, without the required memory scaling with the number of multiplications?
Thank you in advance.