FasterRCNN with modified predictor

Hi,

I’m trying to extend the FastRCNNPredictor with a pretrained classifier.
The classifier predicts the class with the output of the backbone (before fpn), which is stored with a hook. It’s output is torch.Size([bs, num_classes])

I experimented with different batch sizes and expected the input of the predictor as torch.Size([bs*512, 1024]), since I’m using a representation_size of 1024.

This is mostly but not always the case. Is there a way to determine to which image the input to FastRCNNPredictor corresponds? I have to reshape the tensor, in order to concat the results of the classifier.

Here is the code:

class ModFastRCNNPredictor(nn.Module):
    
    def __init__(self, in_channels, num_classes, classifier):
        super(ModFastRCNNPredictor, self).__init__()
        self.cls_score = nn.Linear(in_channels+num_classes, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
        self.classifier = classifier
        self.res_hook = []

    def forward(self, x):
        #print(x.shape)
        if x.dim() == 4:
            assert list(x.shape[2:]) == [1, 1]
        x = x.flatten(start_dim=1)  # torch.Size([bs*512, 1024]) in train, torch.Size([bs*1000, 1024]) in eval
        backbone_out = self.res_hook[0]
        self.classifier.eval()
        x1 = self.classifier(backbone_out) # torch.Size([bs, num_c])
        x1 = torch.sigmoid(x1)
        bs, num_c = x1.shape
        try:
            a = x.view(bs,-1,x.shape[-1]) # torch.Size([bs, 512, 1024]) (train)
            _,d,e = a.shape
            b = torch.stack([x1]*d, dim=1)
            c = torch.cat([a,b],dim=2) # torch.Size([bs, 512, 1024+num_c]) (train)
            c = c.view(c.shape[0]*c.shape[1],-1)
        except Exception as e:
            a,_ = x.shape #123,1024
            c = torch.cat([x,torch.zeros(a,num_c, device=x.device)], dim=1)
            print(e)
            
        scores = self.cls_score(c)
        bbox_deltas = self.bbox_pred(x)

        return scores, bbox_deltas
    
    def hook_fn(self, m, i, o):
        self.res_hook = [o]

Regards,
Robert