With batch size 6 on 4× RTX 3090 GPUs, the output batch size differs from the input batch size

My code runs with batch sizes 8 and 15, but not with batch size 6.
I did not change the batch size anywhere in my model, so this is very strange.

Here is the code. With batch size 6, each image is cropped into 4 pieces, so the input batch size of this module is 6 × 4 = 24, but its output batch size is 18.

class Qtransfer_sp5(nn.Module):
    """Predict per-pixel weights from ``probs`` and use them to upsample a CAM.

    ``probs`` goes through three ResBlock stages (``getQ1`` -> ``getQ2`` ->
    ``getQ3``).  Before the third stage, ``feat`` is resized to the spatial
    size of the second stage's output and concatenated with it along the
    channel axis.  The third stage's output is split into channels ``[:9]``
    and ``[9:]``; each half gets a channel softmax and ``get_noliner``.  The
    resulting parameters drive ``upfeat`` on channel 0 of ``cam_map`` and on
    its remaining channels.

    NOTE(review): if this module (or a model containing it) is wrapped in
    ``nn.DataParallel``, every tensor argument of ``forward`` is split along
    dim 0.  This includes ``logits``, even though it is unused here.  An
    argument whose batch size yields fewer chunks than there are GPUs (for
    example, 6 rows over 4 GPUs chunk into 3 pieces of 2) can reduce the
    number of replicas that run.  The gathered output then silently loses
    samples.  Make sure every argument has the same batch size.
    """

    def __init__(self):
        super().__init__()
        # Presumably ResBlock(in_ch, mid_ch, out_ch, stride=...) -- TODO confirm.
        self.b1 = ResBlock(9, 128, 128, stride=2)
        self.b1_1 = ResBlock(128, 128, 128)
        self.b1_2 = ResBlock(128, 128, 128)

        self.b2 = ResBlock(128, 256, 256, stride=2)
        self.b2_1 = ResBlock(256, 256, 256)
        self.b2_2 = ResBlock(256, 256, 256)

        # Stage 3 consumes torch.cat([feat, stage-2 output]); 256 + 4096 is
        # presumably stage-2 channels + feat channels (TODO confirm feat has
        # 4096 channels).
        self.b3 = ResBlock(256 + 4096, 512, 512, stride=2)
        self.b3_1 = ResBlock(512, 512, 512)
        self.b3_2 = ResBlock(512, 512, 18)

        # The same block objects are registered both as b* attributes and
        # inside these Sequentials.  Keep both so existing checkpoint keys
        # remain loadable.
        self.getQ1 = nn.Sequential(self.b1, self.b1_1, self.b1_2)
        self.getQ2 = nn.Sequential(self.b2, self.b2_1, self.b2_2)
        self.getQ3 = nn.Sequential(self.b3, self.b3_1, self.b3_2)

    def get_Qfeats(self, x, probs):
        """Run the three stages and return the raw parameter map.

        Args:
            x: Feature map fused before stage 3.  It is resized with bilinear
                interpolation to the spatial size of the stage-2 output and
                detached, so no gradient flows back into ``x`` from here.
            probs: Input to stage 1 (9 channels per ``b1``'s construction).

        Returns:
            The output of ``getQ3``.
        """
        x1 = self.getQ1(probs)
        x2 = self.getQ2(x1)
        # Resize to x2's real (H, W) instead of (h/4, w/4).  The concatenation
        # below requires identical spatial sizes, so wherever the old target
        # worked it equals x2's size.  When h or w is not a multiple of 4, the
        # stride-2 stages round and the old target made torch.cat fail.
        feat = F.interpolate(
            x, size=x2.shape[-2:], mode='bilinear', align_corners=False)
        cat = torch.cat([feat.detach(), x2], dim=1)
        return self.getQ3(cat)

    def get_sp_cam(self, logits, deconv_para):
        """Apply ``upfeat`` separately to background and foreground channels.

        Channel 0 of ``logits`` is paired with ``deconv_para[:, :9]``, and
        channels 1 onward are paired with ``deconv_para[:, 9:]``.  Both calls
        pass trailing arguments ``1, 1`` to ``upfeat``.  The two results are
        concatenated along the channel axis.
        """
        bg = upfeat(logits[:, 0:1], deconv_para[:, :9], 1, 1)
        fg = upfeat(logits[:, 1:], deconv_para[:, 9:], 1, 1)
        return torch.cat([bg, fg], dim=1)

    def forward(self, feat, logits, probs, cam_map):
        """Predict upsampling parameters and apply them to ``cam_map``.

        Args:
            feat: Feature map fused inside ``get_Qfeats``.  Its last two
                dimensions also set the output spatial size.
            logits: Not used by this method; kept for interface compatibility.
                It must still share the batch size of the other inputs when
                running under ``nn.DataParallel`` (see the class docstring).
            probs: Input to the parameter-prediction stages.
            cam_map: Map whose channel 0 and channels 1+ are processed with
                separate parameter sets by ``get_sp_cam``.

        Returns:
            The ``get_sp_cam`` result, resized to ``feat``'s spatial size
            (``F.interpolate`` default mode, i.e. nearest).

        Raises:
            RuntimeError: if ``feat`` and ``probs`` have different batch sizes
                (they are concatenated per sample in ``get_Qfeats``).
        """
        if feat.shape[0] != probs.shape[0]:
            raise RuntimeError(
                f"Qtransfer_sp5: feat batch size {feat.shape[0]} does not "
                f"match probs batch size {probs.shape[0]}")
        h_x, w_x = feat.shape[-2:]

        feats_parameters = self.get_Qfeats(feat, probs)
        bg_para = get_noliner(F.softmax(feats_parameters[:, :9], dim=1))
        fg_para = get_noliner(F.softmax(feats_parameters[:, 9:], dim=1))
        deconv_parameters = torch.cat([bg_para, fg_para], dim=1)

        feats = self.get_sp_cam(cam_map, deconv_parameters)
        return F.interpolate(feats, (h_x, w_x))

My torch version is `1.8.1+cu111`.