I added a lstm module in SSD. But I encounter a error ‘NotImplementedError’ when backward().
The forward process work well.
The changed code as follow:
class train_target(Function):
def __init__(self, num_classes, top_k, overlap_thresh, conf_thresh, nms_thresh, use_gpu=True):
self.num_classes = num_classes
self.top_k = top_k
self.threshold = overlap_thresh
# Parameters used in nms.
self.nms_thresh = nms_thresh
if nms_thresh <= 0:
raise ValueError('nms_threshold must be non negative.')
self.conf_thresh = conf_thresh
self.variance = cfg['variance']
self.use_gpu = use_gpu
def forward(self, loc_data, conf_data, prior_data):
"""
Args:
loc_data: (tensor) Loc preds from loc layers
Shape: [batch,num_priors*4]
conf_data: (tensor) Shape: Conf preds from conf layers
Shape: [batch*num_priors,num_classes]
prior_data: (tensor) Prior boxes and variances from priorbox layers
Shape: [1,num_priors,4]
return:
rois:(tensor) rois after decoded loc data and nms
Shape: [batch, top_k, 5]
loc_pred: (tensor) loc after nms
Shape: [batch, top_k, 4]
cls_pred: (tensor) conf_data after nms
Shape: [batch, top_k, num_classes]
"""
priors = prior_data
batch = loc_data.size(0) # batch size
print('priors size: ',priors.size())
priors = priors[:loc_data.size(1), :]
print('after priors size: ', priors.size() )
num_priors = (priors.size(0))
num_classes = self.num_classes
result = torch.zeros(batch, self.top_k, 1 + 3 * 4 +num_classes) # 加入score[1], rois[4], loc[4], priors[4]
conf_preds = conf_data.view(batch, num_priors,
num_classes).transpose(2, 1) # conf_preds size(num,num_classes,num_pirors)
conf_data = conf_data.view(batch, num_priors,
num_classes)
decoded_box = loc_data.new(loc_data.size(0), loc_data.size(1), loc_data.size(2)).zero_()
for i in range(batch):
decoded_boxes = decode(loc_data[i], prior_data, self.variance) # box decode
# For each class, perform nms
conf_scores = conf_preds[i].clone()
loc_keep = loc_data[i].clone()
conf_keep = conf_data[i].clone()
priors_keep = priors.clone()
# print('priors_keep size: ', priors_keep.size())
decoded_box[i] = decoded_boxes
output = []
for cl in range(1, num_classes):
c_mask = conf_scores[cl].gt(self.conf_thresh)
scores = conf_scores[cl][c_mask]
if scores.dim() == 0 or scores.size() == torch.Size([0]):
continue
l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
s_mask = c_mask.unsqueeze(1).expand_as(conf_keep)
boxes = decoded_boxes[l_mask].view(-1, 4)
loc = loc_keep[l_mask].view(-1, 4)
conf = conf_keep[s_mask].view(-1, num_classes)
prior_select = priors_keep[l_mask].view(-1, 4)
# print(conf.size())
# idx of highest scoring and non-overlapping boxes per class
ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
output.append(
torch.cat((scores[ids[:count]].unsqueeze(1),
boxes[ids[:count]], loc[ids[:count]], conf[ids[:count]], prior_select[ids[:count]] ), 1)
)
res = output[0]
for j in range(len(output) - 1):
res = torch.cat((res, output[j + 1]), dim=0)
#
# print('index: ',i,'type',type(res))
# print('res size: ',res.size())
# print('res',res)
# sorted by confidence
# _, indices = res[:, 0].sort(0, descending=True)
sort_conf = res[:, 0].clone().cpu().numpy()
res[:, 0] = i
res = res.cpu().numpy()
b = np.ascontiguousarray(res).view(np.dtype((np.void, res.dtype.itemsize * res.shape[1])))
_, idx = np.unique(b, return_index=True)
keep_res = torch.from_numpy(res[idx])
print('sort_conf type: ',type(sort_conf))
# idx = torch.from_numpy(idx)
sort_val = sort_conf[idx]
sort_val = torch.from_numpy(sort_val).view(-1, 1)
_, indices = sort_val[:, 0].sort(0, descending=True)
res_sel = keep_res[indices][:self.top_k]
result[i][:res_sel.size(0)] = res_sel
index1 = torch.from_numpy(np.array(range(0, 5))).cuda()
index2 = torch.from_numpy(np.array(range(5, 9))).cuda()
index3 = torch.from_numpy(np.array(range(9, 9 + num_classes))).cuda()
index4 = torch.from_numpy(np.array(range(9 + num_classes,9 + num_classes+4))).cuda()
rois = torch.index_select(result, -1, index1)
# print('rois: ',rois.size())
loc = torch.index_select(result, -1, index2)
cls = torch.index_select(result, -1, index3)
priors_out = torch.index_select(result, -1, index4)
# add output priors
return rois, loc, cls, priors_out
Hope someone can help me!