I am trying to convert a custom model which is a little bit complicated to TorchScript. Converting nn.Modules are okay, but some pure Python functions are giving errors. One of them is that I’m uploading now.
Cannot script Python slice value.
Here I’m uploading only one function as it can be run separately. Please refer to this link for the full model code.
@torch.jit.script
def generate_shift_window_attn_mask(h: int,
w: int,
window_size_h: int,
window_size_w: int,
shift_size_h: int,
shift_size_w: int,
):
# ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
# calculate attention mask for SW-MSA
# h, w = input_resolution
img_mask = torch.zeros((1, h, w, 1)).cuda() # 1 H W 1
h_slices = (slice(0, -window_size_h),
slice(-window_size_h, -shift_size_h),
slice(-shift_size_h, None))
w_slices = (slice(0, -window_size_w),
slice(-window_size_w, -shift_size_w),
slice(-shift_size_w, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = split_feature(img_mask, num_splits=w // window_size_w, channel_last=True)
mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_104851/3733626938.py in <module>
1 @torch.jit.script
----> 2 def generate_shift_window_attn_mask(h: int,
3 w: int,
4 window_size_h: torch.Tensor,
5 window_size_w: torch.Tensor,
~/anaconda3/lib/python3.9/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb, example_inputs)
1341 if _rcb is None:
1342 _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
-> 1343 fn = torch._C._jit_script_compile(
1344 qualified_name, ast, _rcb, get_default_args(obj)
1345 )
RuntimeError:
Python slice value cannot be used as a value:
File "/tmp/ipykernel_104851/3733626938.py", line 14
# h, w = input_resolution
img_mask = torch.zeros((1, h, w, 1)).cuda() # 1 H W 1
h_slices = (slice(0, -window_size_h),
~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
slice(-window_size_h, -shift_size_h),
slice(-shift_size_h, None))