Variable-size advanced indexing in modules

I am trying to implement a module whose advanced-indexing tensors have variable sizes (determined only by the input and output shapes). However, it does not work: a cuBLAS RuntimeError occurs during the forward pass.
Here is a toy example.

import torch
import torch.nn as nn
from torch.autograd import Variable

class Block(nn.Module):
    def __init__(self, output_shape):
        super(Block, self).__init__()
        self.output_shape = output_shape

    def forward(self, x):
        batch_size, channel, height, width = x.shape

        # Index matrix (ndarray) computed from the input and output shapes.
        grid = self._get_grid(x.shape, self.output_shape)
        # Variable that depends only on the shapes; no gradients required.
        dist = self._get_dist(x.shape, self.output_shape)

        coef = torch.gather(x, 2, Variable(torch.cuda.LongTensor(grid)))
        x = torch.bmm(x, coef).bmm(dist)
        return x

Note that I need to modify Block.output_shape before each forward pass. I wonder whether there is a better way to implement such a module without writing a C extension.

Could you calculate output_shape from x.shape?
If so, forward could just compute the required output_shape itself, and nothing would need to be mutated between calls.
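A minimal sketch of that idea, assuming (purely for illustration) that the output height can be derived from the input height by halving; the halving rule and the index construction stand in for whatever `_get_grid` computes, and the example gathers along the height dimension instead of the original `gather`/`bmm` pipeline:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    # No output_shape attribute to mutate: forward derives everything
    # it needs from x.shape on each call.

    def forward(self, x):
        batch_size, channel, height, width = x.shape
        # Hypothetical rule: the output height is half the input height.
        output_height = height // 2

        # Build the index tensor on the fly from the shapes. It carries no
        # gradients, so a plain LongTensor on x's device is sufficient.
        grid = torch.arange(output_height, device=x.device) * 2
        grid = grid.view(1, 1, -1, 1).expand(batch_size, channel, -1, width)

        # Gather every other row along the height dimension (dim=2).
        return torch.gather(x, 2, grid)

x = torch.randn(2, 3, 8, 5)
y = Block()(x)
print(y.shape)  # torch.Size([2, 3, 4, 5])
```

Since the index tensor is rebuilt from `x.shape` inside forward, batches of different sizes work without touching any module attribute, and no C extension is needed.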