For example, I have a 3D tensor of shape (10, 5, 100), i.e. (TimeStep, BatchSize, HiddenDimension). I want to slice it along the TimeStep dimension to get an output of shape (8, 5, 100). The problem is that the selected time steps differ for each sample in the batch: they are given by an index matrix mask_matrix of shape (8, 5), whose column i lists the 8 time steps to keep for sample i. One intuitive solution is:
tmp = torch.zeros(8, 5, 100)
for i in range(5):
    tmp[:, i, :] = input_tensor[mask_matrix[:, i], i, :]
input_tensor = tmp
This operation runs inside the Module's forward function, and I suspect the Python loop over the batch may slow down the computation. Any suggestions for a more efficient solution?
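One possible vectorized approach, sketched under the assumption that mask_matrix is a LongTensor of time-step indices (not a boolean mask), is torch.gather along dim 0: expanding the (8, 5) index to (8, 5, 100) makes gather pick the whole hidden vector for each (time step, sample) pair, with no Python loop. The helper names select_loop and select_gather below are hypothetical, introduced only for illustration; the loop version reproduces the code above so the two can be compared.

```python
import torch

def select_loop(x, idx):
    # Reference version: the per-sample loop from the question.
    # x: (T, B, H) input; idx: (K, B) LongTensor of time-step indices.
    k, b = idx.shape
    out = x.new_zeros(k, b, x.size(2))  # same dtype/device as x
    for i in range(b):
        out[:, i, :] = x[idx[:, i], i, :]
    return out

def select_gather(x, idx):
    # Vectorized version: expand idx (K, B) -> (K, B, H) so that
    # out[k, b, h] = x[idx[k, b], b, h]; gather is differentiable w.r.t. x.
    index = idx.unsqueeze(-1).expand(-1, -1, x.size(2))
    return torch.gather(x, 0, index)

def select_indexing(x, idx):
    # Equivalent advanced indexing: idx (K, B) broadcasts with arange (B,),
    # giving out[k, b, :] = x[idx[k, b], b, :].
    return x[idx, torch.arange(idx.size(1), device=x.device)]

torch.manual_seed(0)
x = torch.randn(10, 5, 100)  # (TimeStep, BatchSize, HiddenDimension)
# Hypothetical index matrix: 8 distinct time steps per sample, shape (8, 5).
idx = torch.stack([torch.randperm(10)[:8] for _ in range(5)], dim=1)

out = select_gather(x, idx)
print(out.shape)  # torch.Size([8, 5, 100])
```

Both vectorized variants replace the loop with a single kernel call, which also keeps the computation on the GPU when x lives there; whether this gives a measurable speedup for a batch of only 5 depends on the model and is worth profiling.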