I am using PyTorch's `expand` function to repeat tensors (CNN features of size 2048), but my code is slow even though it runs on the GPU, and on top of that it freezes my whole PC (the graphics become laggy). Below is part of my forward method. Here N = 36 (boxes per batch entry) and B = 16 (batch size), so `box_feats` is 16x36x2048 and `q_rnn` is 16x512. After expanding and reshaping, `b_i` and `b_j` are each 16x1296x2048 (36 x 36 = 1296 box pairs), plus the coordinate channels appended from `box_coords`.
```python
def forward(self, box_feats, q_feats, box_coords):
    enc2, _ = self.QRNN(q_feats.permute(1, 0, 2))
    q_rnn = enc2[-1]
    # add coordinates
    box_feats = torch.cat([box_feats, box_coords], dim=-1)
    N = box_feats.size(1)  # number of boxes
    B = box_feats.size(0)  # batch size
    qst = q_rnn.unsqueeze(1).expand(-1, N * N, -1)
    b_i = box_feats.unsqueeze(2).expand(-1, -1, N, -1)
    b_i = b_i.contiguous().view(B, N ** 2, -1)
    b_j = box_feats.unsqueeze(1).expand(-1, N, -1, -1)
    b_j = b_j.contiguous().view(B, N ** 2, -1)
    # print(b_j.size())
    # concatenate all together
    b_full = torch.cat([b_i, b_j, qst], -1)
    bg = self.g1(b_full)
    bg = F.relu(bg)
    bg = self.g2(bg)
    bg = F.relu(bg)
```
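The following is a minimal sketch, not part of the original code, that checks what the pairing does and how large the result gets. It uses small stand-in sizes (B=2, N=3, D=5) so it runs instantly on CPU. It shows that `expand()` alone is a zero-copy view (the repeated dimension gets stride 0), while `.contiguous()` materialises every pair as a real copy. `pair_tensor_bytes` is a hypothetical helper that estimates the size of `b_full` for the sizes in the question, assuming float32 and ignoring the `box_coords` channels, whose width is not stated.

```python
import torch

# Small stand-in sizes; the question uses B=16, N=36, D=2048.
B, N, D = 2, 3, 5
box_feats = torch.randn(B, N, D)

# expand() returns a view: no new memory, the repeated dim gets stride 0.
b_i_view = box_feats.unsqueeze(2).expand(-1, -1, N, -1)
print(b_i_view.shape)     # torch.Size([2, 3, 3, 5])
print(b_i_view.stride())  # (15, 5, 0, 1)
print(b_i_view.data_ptr() == box_feats.data_ptr())  # True (same storage)

# contiguous() copies the N*N pairs into newly allocated memory.
b_i = b_i_view.contiguous().view(B, N * N, -1)
b_j = box_feats.unsqueeze(1).expand(-1, N, -1, -1).contiguous().view(B, N * N, -1)

# Row k = i*N + j pairs box i (in b_i) with box j (in b_j).
i, j = 1, 2
assert torch.equal(b_i[:, i * N + j], box_feats[:, i])
assert torch.equal(b_j[:, i * N + j], box_feats[:, j])


def pair_tensor_bytes(B, N, feat_dim, q_dim, bytes_per_elem=4):
    """Hypothetical helper: bytes in cat([b_i, b_j, qst], -1), float32 by default."""
    return B * N * N * (2 * feat_dim + q_dim) * bytes_per_elem


# Question's sizes, ignoring the extra box_coords channels.
print(pair_tensor_bytes(16, 36, 2048, 512) / 2**20)  # 364.5 (MiB)
```

Under these assumptions, `b_full` alone is about 364.5 MiB, and the contiguous `b_i` and `b_j` copies plus the outputs of `g1`/`g2` come on top of that.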