I am trying to implement a module whose tensor sizes vary (they are determined only by the input and output shapes) and that relies on advanced indexing. However, it does not work: a cuBLAS RuntimeError is raised during the forward pass.
Here is a toy example.
```python
import torch
import torch.nn as nn
from torch.autograd import Variable


class Block(nn.Module):
    def __init__(self, output_shape):
        super(Block, self).__init__()
        self.output_shape = output_shape

    def forward(self, x):
        batch_size, channel, height, width = x.shape
        # ndarray used for indexing; depends on the input and output shapes.
        grid = self._get_grid(x.shape, self.output_shape)
        # Variable that depends only on the shapes; no gradients required.
        dist = self._get_dist(x.shape, self.output_shape)
        coef = torch.gather(x, 2, Variable(torch.cuda.LongTensor(grid)))
        x = torch.bmm(x, coef).bmm(dist)
        return x
```
Note that I need to modify Block.output_shape by hand before each forward pass. I am wondering whether there is a better way to implement such a module without writing a C extension.
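For concreteness, the call pattern I have in mind looks roughly like the sketch below (loader and compute_output_shape are hypothetical stand-ins; the shapes are made up for illustration):

```python
block = Block(output_shape=None).cuda()

for x in loader:  # hypothetical iterable of (batch, channel, height, width) tensors
    x = x.cuda()
    # output_shape has to be reassigned before every forward pass,
    # since grid and dist are rebuilt from it inside forward().
    block.output_shape = compute_output_shape(x.shape)  # hypothetical helper
    out = block(Variable(x))
```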