Tracing sliced Tensors during forward function

I’m trying to convert a Temporal Shift Module (TSM) model for video action recognition to ONNX. I have worked with ONNX before and have successfully converted models to the format. However, the model I want to convert now performs slicing operations during the forward pass, writing into slices of a tensor in place.

Basically, the chunk of code that is giving me headaches is the following:

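```python
import torch

# Assumed: x is reshaped to (n_batch, n_segment, c, h, w) and fold = c // fold_div
out = torch.zeros_like(x)
out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
```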
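One commonly suggested workaround (not from the TSM authors, and not something I can confirm fixes every exporter version) is to avoid writing into slices of a preallocated tensor and instead build the shifted tensor out of slices joined with torch.cat, so the traced graph contains only slice and concatenation ops. A minimal sketch, assuming x is already shaped (n_batch, n_segment, c, h, w) and fold = c // fold_div; the function names shift_inplace and shift_cat are hypothetical:

```python
import torch


def shift_inplace(x: torch.Tensor, fold: int) -> torch.Tensor:
    # The shift from the question, using in-place slice assignment.
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
    out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
    return out


def shift_cat(x: torch.Tensor, fold: int) -> torch.Tensor:
    # Same result built from slices and torch.cat only, with no in-place writes.
    # Assumed input shape: (n_batch, n_segment, c, h, w).
    zeros = torch.zeros_like(x[:, :1, :fold])  # one empty segment, fold channels
    # Segment t takes segment t + 1; the last segment becomes zeros.
    left = torch.cat((x[:, 1:, :fold], zeros), dim=1)
    # Segment t takes segment t - 1; the first segment becomes zeros.
    right = torch.cat((zeros, x[:, :-1, fold:2 * fold]), dim=1)
    rest = x[:, :, 2 * fold:]  # channels that are not shifted
    return torch.cat((left, right, rest), dim=2)
```

Both functions should give identical outputs, e.g. comparing shift_inplace(x, fold) and shift_cat(x, fold) with torch.equal on a random input is a quick sanity check before attempting an export.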
  • First, I have read and tried the various suggestions in the TSM GitHub issues. Usually someone replies “it works for me”, but I can assure you that none of the suggestions work for me.
  • Second, I tried scripting the model with TorchScript (torch.jit’s ScriptModule), but that doesn’t work either, since TemporalShift objects are not callable.
  • Third, I came across the ONNX “Slice” operator but can’t seem to get it working. In my opinion, its documentation is also not very clear.
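As an alternative to scripting, the same cat-based shift can be placed inside a module and traced with torch.jit.trace, which records the ops executed on an example input rather than compiling the Python source. ShiftOnly below is a hypothetical stand-in for TSM’s TemporalShift (it omits the wrapped convolution) and assumes frames are stacked along the batch dimension as (n_batch * n_segment, c, h, w):

```python
import torch
import torch.nn as nn


class ShiftOnly(nn.Module):
    # Hypothetical module, not the TSM repository's TemporalShift: shift only.
    def __init__(self, n_segment: int = 8, fold_div: int = 8):
        super().__init__()
        self.n_segment = n_segment
        self.fold_div = fold_div

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed input layout: (n_batch * n_segment, c, h, w).
        nt, c, h, w = x.shape
        x = x.view(nt // self.n_segment, self.n_segment, c, h, w)
        fold = c // self.fold_div
        zeros = torch.zeros_like(x[:, :1, :fold])
        left = torch.cat((x[:, 1:, :fold], zeros), dim=1)            # shift left
        right = torch.cat((zeros, x[:, :-1, fold:2 * fold]), dim=1)  # shift right
        out = torch.cat((left, right, x[:, :, 2 * fold:]), dim=2)    # rest unshifted
        return out.view(nt, c, h, w)


model = ShiftOnly(n_segment=8, fold_div=8).eval()
example = torch.randn(16, 64, 7, 7)  # 2 clips x 8 segments, 64 channels
traced = torch.jit.trace(model, example)
```

The traced module (or the eager module itself) could then be handed to torch.onnx.export together with an example input. I would expect the exported graph to use only Slice, Concat and Reshape-style ops, but I have not verified this against every PyTorch release and opset version. Tracing also bakes n_segment and fold_div in as constants.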

Anyway, I’m really hitting a wall here, and I’d appreciate it if somebody could point me in the right direction.