Slow backward pass when using nn.Conv2d as a substitute for nn.Linear

I am trying to implement Faster R-CNN myself. For the roi_head, I used nn.Conv2d with a 7x7 kernel as a substitute for nn.Linear. However, I found that the backward pass is very slow with nn.Conv2d: 100 iterations take around 15 s of backward time with nn.Conv2d, but only 3.7 s with nn.Linear. (The backward time is measured for the whole Faster R-CNN network, since it is hard to time the backward pass of a single module on its own.)

Why is the time difference so big?

The two versions of the code are listed below:

class FastRCNNHead(nn.Module):
    """Fast R-CNN box head built from fully connected (``nn.Linear``) layers.

    Each pooled RoI feature map is flattened and passed through a two-layer
    feature MLP. Two linear heads then predict per-RoI class scores and box
    regression values.

    Args:
        cfg: Config object. Reads ``cfg.class_num``, ``cfg.out_feature_num``,
            ``cfg.roi_pool.pool_h`` and ``cfg.roi_pool.pool_w``.
        feature_num: Width of the hidden feature layers (default 1024, the
            value that was previously hard-coded).
    """

    def __init__(self, cfg, feature_num=1024):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, assigning ``self.feature_head`` raises AttributeError.
        super().__init__()
        self.class_num = cfg.class_num
        class_num = cfg.class_num
        self.cfg = cfg
        pool_h = cfg.roi_pool.pool_h
        pool_w = cfg.roi_pool.pool_w
        # Size of one flattened RoI: channels * pooled height * pooled width.
        in_features = cfg.out_feature_num * pool_h * pool_w

        # NOTE(review): there is no non-linearity between these two layers,
        # so together they collapse into a single affine map. Fast R-CNN
        # heads usually apply a ReLU after each fc layer -- confirm that
        # omitting it is intentional.
        self.feature_head = nn.Sequential(
            nn.Linear(in_features, feature_num),
            nn.Linear(feature_num, feature_num),
        )
        # class_num + 1 outputs (presumably the foreground classes plus
        # background).
        self.label_head = nn.Linear(feature_num, class_num + 1)
        # Four regression values for each of the class_num + 1 outputs.
        self.bbox_head = nn.Linear(feature_num, 4 * (class_num + 1))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, rois):
        """Predict class scores and box regressions for a batch of RoIs.

        Args:
            rois: Pooled RoI features. Every dim from 1 onward is flattened,
                and the flattened size must equal
                ``out_feature_num * pool_h * pool_w``. Presumably the shape is
                (N, out_feature_num, pool_h, pool_w) -- confirm with the caller.

        Returns:
            Tuple ``(label_pre, bbox_pre)`` with shapes (N, class_num + 1)
            and (N, 4 * (class_num + 1)).
        """
        rois = rois.flatten(start_dim=1)
        features = self.feature_head(rois)
        label_pre = self.label_head(features)
        bbox_pre = self.bbox_head(features)
        return label_pre, bbox_pre


class FastRCNNHeadConv(nn.Module):
    """Fast R-CNN box head that uses ``nn.Conv2d`` in place of ``nn.Linear``.

    The first convolution's kernel covers the whole pooled RoI
    (``pool_h x pool_w``, no padding). On a ``pool_h x pool_w`` input it
    therefore yields a 1x1 map and computes the same function as a fully
    connected layer over the flattened RoI. The following 1x1 convolutions
    act as fully connected layers over the channel dimension.

    NOTE(review), performance: ``nn.Linear`` runs each layer as a single
    matrix multiply. This formulation instead goes through the convolution
    kernels (cuDNN on GPU), whose backward pass for a kernel spanning the
    entire input is likely not tuned for this degenerate shape. That is a
    plausible cause of the much slower backward step. Setting
    ``torch.backends.cudnn.benchmark = True`` may let cuDNN pick a faster
    algorithm; otherwise prefer the ``nn.Linear`` head. Verify with a profiler.

    Args:
        cfg: Config object. Reads ``cfg.class_num``, ``cfg.out_feature_num``,
            ``cfg.roi_pool.pool_h`` and ``cfg.roi_pool.pool_w``.
        feature_num: Number of hidden feature channels (default 1024, the
            value that was previously hard-coded).
    """

    def __init__(self, cfg, feature_num=1024):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, assigning ``self.feature_head`` raises AttributeError.
        super().__init__()
        self.class_num = cfg.class_num
        class_num = cfg.class_num
        self.cfg = cfg
        pool_h = cfg.roi_pool.pool_h
        pool_w = cfg.roi_pool.pool_w

        # NOTE(review): no non-linearity between these layers (see the
        # nn.Linear variant); confirm that this is intentional.
        self.feature_head = nn.Sequential(
            # The kernel spans the full pooled RoI, so this acts as an fc layer.
            nn.Conv2d(in_channels=cfg.out_feature_num, out_channels=feature_num,
                      kernel_size=(pool_h, pool_w), padding=0),
            nn.Conv2d(in_channels=feature_num, out_channels=feature_num,
                      kernel_size=1, padding=0),
        )
        # class_num + 1 outputs (presumably the foreground classes plus
        # background).
        self.label_head = nn.Conv2d(in_channels=feature_num, out_channels=class_num + 1,
                                    kernel_size=1, padding=0)
        # Four regression values for each of the class_num + 1 outputs.
        self.bbox_head = nn.Conv2d(in_channels=feature_num, out_channels=4 * (class_num + 1),
                                   kernel_size=1, padding=0)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, rois):
        """Predict class scores and box regressions for a batch of RoIs.

        Args:
            rois: Pooled RoI features, expected as
                (N, out_feature_num, pool_h, pool_w).

        Returns:
            Tuple ``(label_pre, bbox_pre)`` with shapes (N, class_num + 1)
            and (N, 4 * (class_num + 1)).

        Raises:
            RuntimeError: if the spatial size of ``rois`` differs from
                ``pool_h x pool_w``. The head outputs are then not 1x1, so
                the ``view`` below cannot drop the spatial dims.
        """
        features = self.feature_head(rois)
        label_pre = self.label_head(features)
        # Drop the trailing 1x1 spatial dims: (N, C, 1, 1) -> (N, C).
        label_pre = label_pre.view(label_pre.shape[:2])
        bbox_pre = self.bbox_head(features)
        bbox_pre = bbox_pre.view(bbox_pre.shape[:2])
        return label_pre, bbox_pre