Below is the forward pass of my model. The input x is split about time-dimension (last-dim) which has indices till 250. Below is the code…
from torch_geometric.data import Dataset, Data
from torch_geometric.data import Batch
def forward(self, x):
dynamic_result = None
split_x = torch.split(x, 1, dim=-1) # doing split along temporal dim
for i in range(len(split_x)):
y = torch.squeeze(split_x[i]) # y.shape : (B, C, V)
dynamic_weights = self.get_dynamic_adj_mat(y) # dynamic_weights: (B,num_edges)
y = torch.permute(y , (0,2,1)) # y.shape : (B, V, C)
graphs = []
for j in range(y.shape[0]):
dim, idx = 0, torch.tensor([j], dtype = torch.int32).to(device)
d = Data(x = torch.squeeze(torch.index_select(y, dim=dim, index = idx)),
edge_index=self.edge_indices,
edge_attr=torch.squeeze(torch.index_select(dynamic_weights, dim=dim, index = idx)))
graphs.append(d)
batch = Batch.from_data_list(graphs)
dynamic_h = self.dynamic_gnn_layer(x= batch.x , edge_index= batch.edge_index , edge_attr = batch.edge_attr)
dynamic_h = rearrange(dynamic_h , '(b v) c -> b c v' , b = self.batch_sz , v = self.num_nodes)
dynamic_h = torch.unsqueeze(dynamic_h , dim=3)
# pr.yellow(f'Passed Dynamic_GNN_Layer || dynamic_h.shape : {dynamic_h.shape}')
if dynamic_result is None:
dynamic_result = dynamic_h
else :
dynamic_result = torch.cat([dynamic_result, dynamic_h] , dim=3)
The 2 forward loops are the rate limiting step in training of model. The model is not big, but the for loops increases the complexity…
Is there a way to optimize this code? Like tensorizing things…
Please guide…Thanks in advance