Python slice value cannot be used as a value:

I am trying to convert a fairly complicated custom model to TorchScript. Converting the nn.Modules works, but some pure Python functions raise errors. The function I'm posting here is one of them.

Cannot script Python slice value.

I'm posting only this one function because it can be run on its own. Please refer to this link for the full model code.

@torch.jit.script
def generate_shift_window_attn_mask(h: int,
                                    w: int,
                                    window_size_h: int,
                                    window_size_w: int,
                                    shift_size_h: int,
                                    shift_size_w: int,
                                    ):
    # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # calculate attention mask for SW-MSA
    # h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1)).cuda()  # 1 H W 1
    h_slices = (slice(0, -window_size_h),
                slice(-window_size_h, -shift_size_h),
                slice(-shift_size_h, None))
    w_slices = (slice(0, -window_size_w),
                slice(-window_size_w, -shift_size_w),
                slice(-shift_size_w, None))
    cnt = 0
    # distinct loop names keep the int arguments h and w intact for split_feature below
    for hs in h_slices:
        for ws in w_slices:
            img_mask[:, hs, ws, :] = cnt
            cnt += 1

    mask_windows = split_feature(img_mask, num_splits=w // window_size_w, channel_last=True)

    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_104851/3733626938.py in <module>
      1 @torch.jit.script
----> 2 def generate_shift_window_attn_mask(h: int,
      3                                     w: int,
      4                                     window_size_h: torch.Tensor,
      5                                     window_size_w: torch.Tensor,

~/anaconda3/lib/python3.9/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1341         if _rcb is None:
   1342             _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
-> 1343         fn = torch._C._jit_script_compile(
   1344             qualified_name, ast, _rcb, get_default_args(obj)
   1345         )

RuntimeError: 
Python slice value cannot be used as a value:
  File "/tmp/ipykernel_104851/3733626938.py", line 14
    # h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1)).cuda()  # 1 H W 1
    h_slices = (slice(0, -window_size_h),
                ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                slice(-window_size_h, -shift_size_h),
                slice(-shift_size_h, None))
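The three slices in each tuple only describe fixed index ranges along one axis. Assuming 0 < shift_size < window_size < axis length (the usual shifted-window setting), they can be written as plain integer (start, stop) pairs. The helper name `region_bounds` below is hypothetical; the sketch cross-checks it against the original slice objects using Python's built-in `slice.indices`.

```python
from typing import List, Tuple


def region_bounds(size: int, window: int, shift: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, stop) pairs covered by the three slices
    on an axis of length `size`. Assumes 0 < shift < window < size."""
    return [(0, size - window), (size - window, size - shift), (size - shift, size)]


# Cross-check against the same three slices used in h_slices / w_slices.
size, window, shift = 16, 8, 4
slices = (slice(0, -window), slice(-window, -shift), slice(-shift, None))
for s, bounds in zip(slices, region_bounds(size, window, shift)):
    # slice.indices(size) resolves negative and None endpoints to (start, stop, step)
    assert s.indices(size)[:2] == bounds
```

If shift_size were 0, `slice(-0, -0)` would be empty while the integer version would not, so the equivalence only holds under the stated assumption.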

Dear @ptrblck, I have often seen you give good solutions to JIT and PyTorch problems on this forum. What do you think about the problem above?
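One possible workaround, sketched under assumptions and not tested against the full model: TorchScript accepts ordinary slicing syntax with int endpoints inside a subscript, so the `slice(...)` objects can be replaced by integer boundaries. The name `shift_window_region_mask` is hypothetical; it rebuilds only `img_mask` (the `split_feature` step is not reproduced because its code is not in this post), creates the tensor on the CPU (it could be moved with `.to(device)` instead of calling `.cuda()`), and assumes 0 < shift_size < window_size < h, and likewise for w.

```python
import torch


@torch.jit.script
def shift_window_region_mask(h: int, w: int,
                             window_size_h: int, window_size_w: int,
                             shift_size_h: int, shift_size_w: int) -> torch.Tensor:
    # Hypothetical helper: region-id mask of shape 1 x H x W x 1
    # (img_mask in the original function).
    img_mask = torch.zeros([1, h, w, 1])
    # Integer region boundaries per axis, replacing the Python slice objects.
    # Assumes 0 < shift_size < window_size < size on each axis.
    h_bounds = [0, h - window_size_h, h - shift_size_h, h]
    w_bounds = [0, w - window_size_w, w - shift_size_w, w]
    cnt = 0
    for i in range(3):
        for j in range(3):
            # start:stop with int endpoints inside the subscript is scriptable
            img_mask[:, h_bounds[i]:h_bounds[i + 1], w_bounds[j]:w_bounds[j + 1], :] = cnt
            cnt += 1
    return img_mask


mask = shift_window_region_mask(8, 8, 4, 4, 2, 2)
```

The rest of the original function (`split_feature`, `view`, the pairwise difference and `masked_fill`) uses only tensor ops on `img_mask`, so it could presumably follow unchanged, provided `split_feature` is itself scriptable.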