Error: NotImplementedError when backward()

I added an LSTM module to SSD, but I encounter a 'NotImplementedError' when calling backward().
The forward pass works fine.
The changed code is as follows:

import numpy as np
import torch
from torch.autograd import Function

# decode, nms, and cfg come from the surrounding SSD codebase
# (box decoding / NMS helpers and the dataset config holding 'variance')

class train_target(Function):
    def __init__(self, num_classes, top_k, overlap_thresh, conf_thresh, nms_thresh, use_gpu=True):
        self.num_classes = num_classes
        self.top_k = top_k
        self.threshold = overlap_thresh
        # Parameters used in nms.
        self.nms_thresh = nms_thresh
        if nms_thresh <= 0:
            raise ValueError('nms_threshold must be non negative.')
        self.conf_thresh = conf_thresh
        self.variance = cfg['variance']
        self.use_gpu = use_gpu

    def forward(self, loc_data, conf_data, prior_data):
        """
        Args:
            loc_data: (tensor) Loc preds from loc layers
                Shape: [batch, num_priors, 4]
            conf_data: (tensor) Conf preds from conf layers
                Shape: [batch*num_priors, num_classes]
            prior_data: (tensor) Prior boxes and variances from priorbox layers
                Shape: [1, num_priors, 4]
        return:
            rois:(tensor) rois after decoded loc data and nms
                Shape: [batch, top_k, 5]
            loc_pred: (tensor) loc after nms
                Shape: [batch, top_k, 4]
            cls_pred: (tensor) conf_data after nms
                Shape: [batch, top_k, num_classes]
        """

        priors = prior_data
        batch = loc_data.size(0)  # batch size
        priors = priors[:loc_data.size(1), :]
        num_priors = priors.size(0)
        num_classes = self.num_classes

        # each row packs: score[1] + rois[4] + loc[4] + conf[num_classes] + priors[4]
        result = torch.zeros(batch, self.top_k, 1 + 3 * 4 + num_classes)

        conf_preds = conf_data.view(batch, num_priors,
                                    num_classes).transpose(2, 1)  # [batch, num_classes, num_priors]
        conf_data = conf_data.view(batch, num_priors, num_classes)

        decoded_box = loc_data.new(loc_data.size(0), loc_data.size(1), loc_data.size(2)).zero_()
        for i in range(batch):
            decoded_boxes = decode(loc_data[i], prior_data, self.variance)  # box decode
            # For each class, perform nms
            conf_scores = conf_preds[i].clone()
            loc_keep = loc_data[i].clone()
            conf_keep = conf_data[i].clone()
            priors_keep = priors.clone()

            decoded_box[i] = decoded_boxes
            output = []
            for cl in range(1, num_classes):
                c_mask = conf_scores[cl].gt(self.conf_thresh)
                scores = conf_scores[cl][c_mask]
                if scores.numel() == 0:
                    continue
                l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                s_mask = c_mask.unsqueeze(1).expand_as(conf_keep)

                boxes = decoded_boxes[l_mask].view(-1, 4)
                loc = loc_keep[l_mask].view(-1, 4)
                conf = conf_keep[s_mask].view(-1, num_classes)
                prior_select = priors_keep[l_mask].view(-1, 4)

                # idx of highest scoring and non-overlapping boxes per class
                ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
                output.append(
                    torch.cat((scores[ids[:count]].unsqueeze(1),
                               boxes[ids[:count]], loc[ids[:count]],
                               conf[ids[:count]], prior_select[ids[:count]]), 1))

            if not output:  # no class cleared the confidence threshold
                continue
            res = torch.cat(output, dim=0)

            # Deduplicate rows via a numpy round-trip. Note that converting
            # to numpy detaches these tensors from the autograd graph.
            sort_conf = res[:, 0].clone().cpu().numpy()
            res[:, 0] = i  # overwrite the score column with the batch index

            res = res.cpu().numpy()
            b = np.ascontiguousarray(res).view(
                np.dtype((np.void, res.dtype.itemsize * res.shape[1])))
            _, idx = np.unique(b, return_index=True)
            keep_res = torch.from_numpy(res[idx])
            sort_val = torch.from_numpy(sort_conf[idx]).view(-1, 1)

            # sort the unique rows by confidence
            _, indices = sort_val[:, 0].sort(0, descending=True)

            res_sel = keep_res[indices][:self.top_k]
            result[i][:res_sel.size(0)] = res_sel

        # split the packed result back into its parts; this only depends on
        # the fully filled result, so it is done once, outside the batch loop
        index1 = torch.arange(0, 5).long().cuda()
        index2 = torch.arange(5, 9).long().cuda()
        index3 = torch.arange(9, 9 + num_classes).long().cuda()
        index4 = torch.arange(9 + num_classes, 9 + num_classes + 4).long().cuda()

        rois = torch.index_select(result, -1, index1)
        loc = torch.index_select(result, -1, index2)
        cls = torch.index_select(result, -1, index3)
        priors_out = torch.index_select(result, -1, index4)

        return rois, loc, cls, priors_out

Hope someone can help me!

You also need to define a backward() function, similar to forward() (see the PyTorch tutorial "PyTorch: Defining New autograd Functions").
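For reference, here is a minimal sketch of a custom autograd Function with both passes, in the static-method style that tutorial uses. Square is a toy example for illustration, not part of the SSD code above:

import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        # stash the input; backward() will need it
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # d(x^2)/dx = 2x, chained with the incoming gradient
        x, = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(3, requires_grad=True)
y = Square.apply(x).sum()
y.backward()  # no NotImplementedError, since backward() is defined

Without the backward() staticmethod, the y.backward() call above raises exactly the NotImplementedError you are seeing.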

Thank you!
So that means a class inheriting from Function needs a backward(). I had misunderstood and thought autograd could differentiate the method automatically.

Yes. A class inheriting from Function should implement backward().

Thank you! I changed the class to inherit from nn.Module instead, and now it works.
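For anyone who lands here later: inheriting from nn.Module works because autograd records the differentiable operations used inside forward() and derives the backward pass on its own, so no hand-written backward() is needed. A minimal sketch (the class name, shapes, and the dummy computation are illustrative only, not the actual SSD head):

import torch
import torch.nn as nn

class TrainTarget(nn.Module):
    def __init__(self, num_classes):
        super(TrainTarget, self).__init__()
        self.num_classes = num_classes

    def forward(self, loc_data, conf_data):
        # any differentiable torch ops here are tracked by autograd,
        # so gradients flow without a custom backward()
        return loc_data * 2 + conf_data.mean()

layer = TrainTarget(num_classes=21)
loc = torch.randn(2, 8732, 4, requires_grad=True)
conf = torch.randn(2, 8732 * 21, requires_grad=True)
out = layer(loc, conf).sum()
out.backward()  # works: autograd built the graph during forward()

One caveat: non-torch steps such as the numpy round-trip in the forward() above still detach tensors from the graph, so gradients will not flow through that part even inside an nn.Module.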