I want to use roi_pooling module which is defined in fast r-cnn, as follows:
class AdaptiveMaxPool2d(Function):
def init(self, out_w, out_h):
super(AdaptiveMaxPool2d, self).init()
self.out_w = out_w
self.out_h = out_h
def forward(self, input):
output = input.new()
indices = input.new().long()
self.save_for_backward(input)
self.indices = indices
self._backend = type2backend[type(input)]
self._backend.SpatialAdaptiveMaxPooling_updateOutput(
self._backend.library_state, input, output, indices,
self.out_w, self.out_h)
return output
def backward(self, grad_output):
input, = self.saved_tensors
indices = self.indices
grad_input = grad_output.new()
self._backend.SpatialAdaptiveMaxPooling_updateGradInput(
self._backend.library_state, input, grad_output, grad_input,
indices)
return grad_input, None
def adaptive_max_pool(input, size):
return AdaptiveMaxPool2d(size[0],size[1])(input)
def roi_pooling(input, rois, size=(7,7), spatial_scale=1.0):
assert(rois.dim() == 2)
assert(rois.size(1) == 5)
output = []
rois = rois.data.float()
num_rois = rois.size(0)
rois[:,1:].mul_(spatial_scale)
rois = rois.long()
for i in range(num_rois):
roi = rois[i]
im_idx = roi[0]
im = input.narrow(0, im_idx, 1)[..., roi[2]:(roi[4]+1), roi[1]:(roi[3]+1)]
output.append(adaptive_max_pool(im, size))
return torch.cat(output, 0)
my code for call the function is as:
roi_pooling(myfeaturemap, rois, size=(1,1), spatial_scale=1.0/16)
where the type of myfeaturemap is torch.Tensor
but I got errors ,like follows:
File “/home/user/lr/twostage/roi_pooling.py”, line 54, in roi_pooling
output.append(adaptive_max_pool(im, size))
File “/home/user/lr/twostage/roi_pooling.py”, line 37, in adaptive_max_pool
return AdaptiveMaxPool2d(size[0],size[1])(input)
File “/home/user/lr/twostage/roi_pooling.py”, line 21, in forward
self._backend = type2backend[type(input)]
File “/usr/local/lib/python2.7/dist-packages/torch/_thnn/init.py”, line 15, in getitem
return self.backends[name].load()
KeyError: <class ‘torch.Tensor’>
, I don’t have ideas about above error, could anybody help me?