Roi Pool using AdaptiveMaxPool

Although ROI pooling is now available in torchvision as a layer, I need to implement it for 3D inputs. For that, I think I can make use of the AdaptiveMaxPool3d layer. Considering that this approach might have issues, as mentioned here and here, I have made some changes to the code. However, it still gives me values that are smaller than expected. I am attaching the source code here:

def pad_input(im, output_size):
    """
    Replicate-pad `im` so that every spatial dim is at least `output_size`.

    im: B, C, W, H, L
    output_size: (out_w, out_h, out_l)

    Only the trailing side of a too-small dimension is padded. Dimensions that
    are already large enough are left untouched. If nothing needs padding,
    `im` itself is returned.
    """
    _, _, w, h, l = im.shape
    # Clamp at 0: a negative pad amount makes F.pad *crop* that dimension,
    # which would silently throw away values (and possibly the maximum) of a
    # dimension that is already larger than the requested output size.
    pad_w = max(output_size[0] - w, 0)
    pad_h = max(output_size[1] - h, 0)
    pad_l = max(output_size[2] - l, 0)
    if pad_w == 0 and pad_h == 0 and pad_l == 0:
        return im
    # F.pad expects the pairs for the last dimension first: (L, H, W).
    padding_tuple = (0, pad_l, 0, pad_h, 0, pad_w)
    return F.pad(im, padding_tuple, mode="replicate")


def roi_pool_simple(input, rois, output_size, spatial_scale):
    """
    Max-pool every 3D region of interest of `input` to a fixed `output_size`.

    input: B, C, W, H, L
    rois : num_roi, 6 -- (start_w, start_h, start_l, end_w, end_h, end_l)
    output_size: (4, 4, 4)
    spatial_scale: 1/4, roi in input space so we need a way to downsample to specific shape

    Returns a tensor of shape (num_roi * B, C, *output_size). ROIs that start
    beyond the feature map contribute zeros. Start indices are clamped to 0
    and every ROI covers at least one voxel per axis.
    """
    output_size = tuple(output_size)
    channels = input.shape[1]
    max_w, max_h, max_l = input.shape[2], input.shape[3], input.shape[4]
    pooled_shape = (input.shape[0], channels) + output_size

    # torch.cat cannot handle an empty list, so answer "no ROIs" directly.
    if rois.size(0) == 0:
        return input.new_zeros((0, channels) + output_size)

    # Scale out of place so the caller's tensor is untouched, and promote
    # integer ROIs first: multiplying them in place by a float scale raises.
    if not rois.is_floating_point():
        rois = rois.float()
    scaled_rois = (rois.detach() * spatial_scale).tolist()

    output = []
    # Here bs = 1 so we directly use it in the index
    for start_w, start_h, start_l, end_w, end_h, end_l in scaled_rois:
        # round down the start indices; clamp at 0 because a negative index
        # would wrap around to the far end of the feature map
        start_w = max(math.floor(start_w), 0)
        start_h = max(math.floor(start_h), 0)
        start_l = max(math.floor(start_l), 0)
        # round up the end indices; keep at least one voxel so the crop is
        # never empty (replicate padding and pooling both need data)
        end_w = max(math.ceil(end_w), start_w + 1)
        end_h = max(math.ceil(end_h), start_h + 1)
        end_l = max(math.ceil(end_l), start_l + 1)
        # We would have empty channels if the start index is greater than the dimensions. So,
        if start_w >= max_w or start_h >= max_h or start_l >= max_l:
            # Same shape, dtype and device as the pooled crops so that the
            # final torch.cat works.
            output.append(input.new_zeros(pooled_shape))
        else:
            im = input[..., start_w:end_w, start_h:end_h, start_l:end_l]
            im = pad_input(im, output_size)
            output.append(F.adaptive_max_pool3d(im, output_size))
    return torch.cat(output, dim=0)