I am working on an nn.Module.
The forward method of the module takes two tensors as input. It first applies a CNN to both of these inputs. It then needs to concatenate the first output with every row of the second output.
Here is what I have so far:
class Model(nn.Module):
    """Encode two inputs with a shared module and pair the first with every row of the second."""

    def __init__(self):
        super().__init__()
        self.layer1 = SomeModule1()
        self.layer2 = SomeModule2()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Concatenate ``layer1(x)`` with each row of ``layer1(y)`` and apply ``layer2``.

        Args:
            x: 1-D input tensor (length ``n``).
            y: 2-D input tensor of shape ``(m, n)``.

        Returns:
            ``layer2`` applied to a matrix with one row per row of ``y``,
            where row ``i`` is ``[layer1(x), layer1(y)[i]]``.
        """
        # NOTE(review): assumes layer1 returns a 1-D tensor for x and a 2-D
        # tensor with one row per row of y -- confirm for SomeModule1.
        x = self.layer1(x)
        y = self.layer1(y)
        # One output row per row of y, each holding both feature vectors side
        # by side. new_zeros keeps y's dtype and device.
        cat = y.new_zeros((y.size(0), x.size(0) + y.size(1)))
        for i in range(y.size(0)):
            # torch.cat takes a sequence of tensors, not separate arguments.
            cat[i] = torch.cat((x, y[i]))
        return self.layer2(cat)
The code works. However, I feel there is a better way to achieve this. The resulting tensor shows a `grad_fn=` attribute; I am not sure exactly what it is, but I have the feeling that it is not efficient.
What would be the most efficient way to do this concatenation?
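Hi, here are three solutions. The code is a little chaotic, but it should be understandable. Note that the `grad_fn` you see is expected: autograd attaches it to every tensor computed from tensors that require gradients so that it can backpropagate, and it is not by itself a sign of inefficiency. The real cost in your version is the Python loop; method C below avoids it entirely: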
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
class Model(nn.Module):
    """Toy model comparing three ways to pair a vector with every row of a matrix.

    ``layer1`` embeds both inputs; each strategy then builds a matrix whose
    i-th row is ``[layer1(x), layer1(y)[i]]`` and feeds it to ``layer2``.
    All three strategies produce the same result.
    """

    def __init__(self, in_features=10, hidden_features=20, out_features=30):
        """Create the two linear layers.

        Args:
            in_features: size of each input vector.
            hidden_features: output size of ``layer1``.
            out_features: output size of ``layer2``.
        """
        super().__init__()
        self.layer1 = nn.Linear(in_features, hidden_features)
        # layer2 consumes the concatenation of two hidden vectors.
        self.layer2 = nn.Linear(2 * hidden_features, out_features)

    def forwardA(self, x, y):
        """Fill a preallocated buffer row by row (Python loop, in-place writes).

        Args:
            x: 1-D tensor of shape ``(in_features,)``.
            y: 2-D tensor of shape ``(m, in_features)``.

        Returns:
            Tensor of shape ``(m, out_features)``.
        """
        x = self.layer1(x)
        y = self.layer1(y)
        # Derive the width from the actual tensors instead of hard-coding it,
        # and allocate with y's dtype/device so the model also works in
        # float64 or on a GPU.
        cat = y.new_zeros((y.shape[0], x.shape[0] + y.shape[1]))
        for i, row in enumerate(y):
            cat[i] = torch.cat((x, row))
        return self.layer2(cat)

    def forwardB(self, x, y):
        """Build each row out-of-place, then stack them once (Python loop, no in-place writes).

        Args:
            x: 1-D tensor of shape ``(in_features,)``.
            y: 2-D tensor of shape ``(m, in_features)`` with ``m >= 1``
                (``torch.stack`` rejects an empty list).

        Returns:
            Tensor of shape ``(m, out_features)``.
        """
        x = self.layer1(x)
        y = self.layer1(y)
        cat = torch.stack([torch.cat((x, row)) for row in y])
        return self.layer2(cat)

    def forwardC(self, x, y):
        """Broadcast ``x`` across the rows of ``y`` and concatenate once (fully vectorized).

        Args:
            x: 1-D tensor of shape ``(in_features,)``.
            y: 2-D tensor of shape ``(m, in_features)``.

        Returns:
            Tensor of shape ``(m, out_features)``.
        """
        x = self.layer1(x)
        y = self.layer1(y)
        # expand() is a zero-copy broadcast view (unlike repeat(), which copies);
        # torch.cat materialises the final matrix exactly once.
        x = x.unsqueeze(0).expand(y.shape[0], -1)
        cat = torch.cat((x, y), dim=1)
        return self.layer2(cat)
if __name__ == "__main__":
    net = Model()

    # Correctness: all three strategies must agree exactly on the same inputs.
    x = torch.rand(10)
    y = torch.rand((5, 10))
    results = {
        "A": net.forwardA(x, y),
        "B": net.forwardB(x, y),
        "C": net.forwardC(x, y),
    }
    for first, second in (("A", "B"), ("A", "C"), ("B", "C")):
        if results[first].equal(results[second]):
            print(f"ok_{first}{second}")
        else:
            print(f"MISMATCH_{first}{second}")

    # Timing: generate inputs up front so only the forward passes are measured,
    # and use perf_counter (monotonic, high resolution) rather than time.time().
    n_iters = 1000
    inputs = [(torch.rand(10), torch.rand((5, 10))) for _ in range(n_iters)]
    methods = (("A", net.forwardA), ("B", net.forwardB), ("C", net.forwardC))
    for label, forward in methods:
        start = time.perf_counter()
        for x, y in inputs:
            forward(x, y)
        print(f"Method {label}:", time.perf_counter() - start)