I am using the expand function in PyTorch to repeat tensors (CNN features of size 2048), but my code is slow even though it runs on the GPU. On top of that, it freezes my whole PC (the graphics become laggy). Below is part of my forward method.

Here N = 36 (the number of boxes for one entry of the batch) and B = 16 (the batch size). box_feats is 16x36x2048 before the coordinates are concatenated. After the expand and view, b_i and b_j are each 16x1296x(2048 + C), where C is the number of columns in box_coords and 1296 = N*N. q_rnn is 16x512.
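def forward(self, box_feats, q_feats, box_coords):
    # Encode the question (input permuted to sequence-first);
    # enc2[-1] is the output at the last step, so q_rnn is B x 512
    enc2, _ = self.QRNN(q_feats.permute(1, 0, 2))
    q_rnn = enc2[-1]
    # add coordinates: box_feats becomes B x N x (2048 + C)
    box_feats = torch.cat([box_feats, box_coords], dim=-1)
    N = box_feats.size(1)  # number of boxes
    B = box_feats.size(0)  # batch size
    # repeat the question vector for every (i, j) box pair: B x N*N x 512 (a view, no copy yet)
    qst = q_rnn.unsqueeze(1).expand(-1, N * N, -1)
    # row i*N + j of b_i holds box i; contiguous() materializes the N-fold copy
    b_i = box_feats.unsqueeze(2).expand(-1, -1, N, -1)
    b_i = b_i.contiguous().view(B, N ** 2, -1)
    # row i*N + j of b_j holds box j
    b_j = box_feats.unsqueeze(1).expand(-1, N, -1, -1)
    b_j = b_j.contiguous().view(B, N ** 2, -1)
    # print(b_j.size())
    # concatenate all together: B x N*N x (2 * (2048 + C) + 512)
    b_full = torch.cat([b_i, b_j, qst], -1)
    bg = self.g1(b_full)
    bg = F.relu(bg)
    bg = self.g2(bg)
    bg = F.relu(bg)
    # ... (rest of forward not shown)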
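A minimal, self-contained sketch with small stand-in sizes shows what the expand/contiguous pattern above does to memory: expand only creates a view (the repeated dimension has stride 0), while contiguous() copies the data N times. It also estimates the size of b_full at the question's dimensions. The helper b_full_numel and the coordinate width C=4 are hypothetical; the width of box_coords is not stated above.

```python
import torch

# Small stand-in sizes so this runs quickly on CPU; the question uses B=16, N=36.
B, N, D = 2, 3, 5
box_feats = torch.randn(B, N, D)

# expand() returns a view: the new dimension has stride 0 and no data is copied.
b_i_view = box_feats.unsqueeze(2).expand(-1, -1, N, -1)  # B x N x N x D
print(b_i_view.stride())                             # stride of dim 2 is 0
print(b_i_view.data_ptr() == box_feats.data_ptr())   # same underlying storage

# contiguous() materializes every repeated row, so memory grows by a factor of N.
b_i = b_i_view.contiguous().view(B, N * N, -1)
b_j = box_feats.unsqueeze(1).expand(-1, N, -1, -1).contiguous().view(B, N * N, -1)
print(b_i.numel() // box_feats.numel())              # equals N

# Row i*N + j pairs box i (from b_i) with box j (from b_j).
i, j = 1, 2
assert torch.equal(b_i[:, i * N + j], box_feats[:, i])
assert torch.equal(b_j[:, i * N + j], box_feats[:, j])


def b_full_numel(B=16, N=36, feat=2048, C=4, q_dim=512):
    """Element count of b_full = cat([b_i, b_j, qst], -1).

    C (number of coordinate columns) is a hypothetical default.
    """
    D = feat + C
    return B * N * N * (2 * D + q_dim)


numel = b_full_numel()
print(numel, f"{numel * 4 / 2**20:.0f} MiB as float32")
```

With the assumed C=4, b_full alone holds about 95.7 million float32 values (roughly 365 MiB) before g1 even runs. The contiguous copies of b_i and b_j, the activations of g1 and g2, and anything autograd keeps for the backward pass come on top of that, which may help explain the slowdown.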