How to save a wrapped transformer model using torch.jit.trace?

I have a custom module that wraps a sequence classification model, and I’m having trouble saving the wrapped model using torch.jit.trace.

When I run the following line, I get this ambiguous runtime error:

torch.jit.trace(IGWrapper(model=model), (input_ids, attention_mask, baseline))

RuntimeError: 0 INTERNAL ASSERT FAILED at “…/torch/csrc/jit/ir/alias_analysis.cpp”:616, please report a bug to PyTorch. We don’t have an op for aten::select but it isn’t a special case. Argument types: Tensor?, int, int,

Candidates:
aten::select.Dimname(Tensor(a) self, str dim, int index) → Tensor(a)
aten::select.int(Tensor(a) self, int dim, SymInt index) → Tensor(a)
aten::select.t(t list, int idx) → t(*)

Here is my method for context:

@torch.jit.script
def calc_integral(grads):
    """Average stacked gradients with the trapezoidal rule.

    ``grads`` is indexed along its first dimension: each pair of
    neighbouring slices is averaged, and the resulting midpoints are then
    averaged over that same dimension.
    """
    # Midpoint of every adjacent pair along dim 0 (one fewer slice than the input).
    lower = grads[:-1]
    upper = grads[1:]
    midpoints = (lower + upper) / 2
    return midpoints.mean(0)

@torch.jit.script
def _embedding_grad(logits: Tensor, out_embed: Tensor) -> Tensor:
    """Return the gradient of ``logits`` (weighted by ones) w.r.t. ``out_embed``.

    Compiled once at import time instead of being re-scripted on every
    interpolation step.
    """
    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(logits)]
    grad = torch.autograd.grad(
        [logits], [out_embed], grad_outputs=grad_outputs, create_graph=True
    )[0]
    # TorchScript types each returned gradient as Optional[Tensor]. Refine it
    # here so callers receive a plain Tensor; indexing the Optional value is
    # what produces an `aten::select` call with a `Tensor?` argument.
    assert grad is not None
    return grad


class IGWrapper(nn.Module):
    """Wrap a sequence classification model and return per-token attribution scores.

    The wrapped model is put in eval mode. ``forward`` accumulates gradients of
    the logits with respect to the first hidden state over ``num_steps + 1``
    interpolation steps, averages them with ``calc_integral`` and returns
    min-max normalized scores for the first item in the batch.

    Args:
        model: Model exposing ``base_model.embeddings.word_embeddings`` and
            returning an object with ``logits`` and ``hidden_states``.
        num_steps: Number of interpolation intervals (``num_steps + 1``
            gradient evaluations). Must be at least 1.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """

    def __init__(self, model, num_steps: int = 5):
        super().__init__()
        if num_steps < 1:
            # step / num_steps would divide by zero, and the trapezoid rule
            # needs at least two gradient samples.
            raise ValueError("num_steps must be at least 1")
        self.model = model.eval()
        self.num_steps = num_steps

    def forward(self,
                input_ids: Tensor,
                attention_mask: Tensor,
                baseline: Optional[Tensor] = None):
        """Compute normalized attribution scores for the first batch item.

        Args:
            input_ids: Token ids passed to the wrapped model.
            attention_mask: Attention mask passed to the wrapped model.
            baseline: Baseline embedding; defaults to zeros shaped like the
                word embeddings of ``input_ids``.

        Returns:
            Detached tensor of min-max normalized scores for batch index 0.
        """
        # Scope gradient tracking to this call instead of flipping the global
        # autograd mode and leaving it enabled for the caller.
        with torch.enable_grad():
            input_embed = self.model.base_model.embeddings.word_embeddings(input_ids)
            copy_embed = torch.clone(input_embed)

            if baseline is None:
                # create baseline
                baseline = torch.zeros_like(copy_embed)

            num_steps = self.num_steps
            step_grads = []

            for step in range(num_steps + 1):
                print(f"step: {step}/{num_steps}")
                # NOTE(review): input_embed is not passed to self.model below
                # (the model is called with input_ids), so confirm that this
                # interpolation actually reaches the forward pass.
                input_embed.data = baseline + step/num_steps * (copy_embed - baseline)
                outputs = self.model(input_ids, attention_mask, output_hidden_states=True, output_attentions=True)

                logits, hidden_states = outputs.logits, outputs.hidden_states

                # calculate the derivatives w.r.t. the output embeddings
                out_embed = hidden_states[0]
                g = _embedding_grad(logits, out_embed)

                # Keep the gradient of the first batch item only.
                step_grads.append(g[0])

            # stack grads along first dimension to create a new tensor
            stacked_grads = torch.stack(step_grads)

            avg_grad = calc_integral(stacked_grads)

            # out_embed is the first hidden state from the final step.
            integrated_grads = out_embed * avg_grad

            scores = torch.sqrt((integrated_grads ** 2).sum(-1))

            # normalize scores
            max_s, min_s = scores.max(1, True).values, scores.min(1, True).values

            normalized_scores = (scores - min_s) / (max_s - min_s)
            return normalized_scores[0].detach()
    

FYI: torch.jit is in maintenance mode, so the recommendation is to switch to torch.export if you need a Python-free deployment, or to torch.compile() if you only care about speedups.