I am trying to implement a module that uses variable-size index tensors (determined only by the input and output shapes) for advanced indexing. However, it does not work: a cuBLAS RuntimeError is raised during the forward pass.
Here is a toy example:
import torch
import torch.nn as nn
from torch.autograd import Variable

class Block(nn.Module):
    def __init__(self, output_shape):
        super(Block, self).__init__()
        self.output_shape = output_shape

    def forward(self, x):
        batch_size, channel, height, width = x.shape
        # ndarray index matrix computed from the input and output shapes.
        grid = self._get_grid(x.shape, self.output_shape)
        # Variable that depends only on the shapes; no gradients required.
        dist = self._get_dist(x.shape, self.output_shape)
        coef = torch.gather(x, 2, Variable(torch.cuda.LongTensor(grid)))
        x = torch.bmm(x, coef).bmm(dist)
        return x
Note that I need to modify Block.output_shape before each forward pass. I am wondering whether there is a better way to implement such a module without writing a C extension.
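For context, this is roughly how the module is driven. It is only a minimal sketch: the loader and the shapes are made up for illustration, not part of my real code.

block = Block(output_shape=(8, 8)).cuda()

for x, target_shape in loader:  # hypothetical iterable yielding inputs and target shapes
    block.output_shape = target_shape  # mutated before every forward pass
    y = block(Variable(x.cuda()))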