Scripted training

Hi,
I am trying to optimize some parameters online inside a module compiled with torch.jit.script, which means the training operations have to be exposed with @torch.jit.export. When I write the update step by hand it works, but torch.optim optimizers do not seem to be supported (see the code at the bottom).

In the module, manual_optim_step(x) works as expected (it minimizes the loss), but when I script the module with the @torch.jit.export decorator on optim_step(x) uncommented, I get the following error:

```
RuntimeError: 
Module 'MyOptimModule' has no attribute 'optim' (This attribute exists on the Python module, but we failed to convert Python type: 'torch.optim.sgd.SGD' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type SGD.. Its type was inferred; try adding a type annotation for the attribute.):
  File "<ipython-input-5-4d7f596ef214>", line 44
        loss = self.forward(x).sum().abs()
        
        self.optim.zero_grad()
        ~~~~~~~~~~ <--- HERE
        loss.backward()
```
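
The message says the optimizer attribute could not be converted to a TorchScript type. A minimal sketch to isolate this (the class names and the scripts_ok helper are hypothetical, written only for this check): TorchScript records unconvertible attributes and only raises when a compiled method actually touches one, so a module that merely stores an optimizer still scripts.

```python
import torch


class WithOptimAccess(torch.nn.Module):
    """Stores an optimizer AND uses it from an exported (compiled) method."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 3)
        self.optim = torch.optim.SGD(self.parameters(), lr=0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    @torch.jit.export
    def step(self) -> None:
        self.optim.step()  # TorchScript has no type for `self.optim`


class WithOptimUnused(torch.nn.Module):
    """Stores the same optimizer, but no compiled method touches it."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 3)
        self.optim = torch.optim.SGD(self.parameters(), lr=0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def scripts_ok(module: torch.nn.Module) -> bool:
    """Return True if torch.jit.script compiles the module without error."""
    try:
        torch.jit.script(module)
        return True
    except Exception:  # compilation errors surface as RuntimeError, catch broadly
        return False


print(scripts_ok(WithOptimAccess()))  # -> False
print(scripts_ok(WithOptimUnused()))  # -> True
```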

I assume this error simply means that torch.optim is not yet supported in TorchScript.
So my questions are:
(a) Are there any plans to support it?
(b) What is the best approach to training parameters inside a scripted module? Is it simply to reimplement gradient-based optimization by hand, as I do in manual_optim_step()?
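
One workaround for (b), assuming only the model (and not the training loop) needs to run as TorchScript: script the model alone and drive it with an ordinary torch.optim optimizer from Python. A scripted module still exposes .parameters() to Python, so the optimizer updates the scripted weights in place. The Model class, the target and the squared-error loss below are illustrative assumptions, not part of the original code.

```python
import torch

torch.manual_seed(0)


class Model(torch.nn.Module):
    """Hypothetical model with no optimizer attribute, so it scripts cleanly."""

    def __init__(self, n: int, m: int) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(n, m)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


scripted = torch.jit.script(Model(2, 3))
# The optimizer lives in Python and holds references to the scripted
# module's parameters, so step() changes the weights the scripted code uses.
optim = torch.optim.SGD(scripted.parameters(), lr=0.1)

x = torch.rand(2)
target = torch.zeros(3)  # assumed target for the illustrative loss
losses = []
for _ in range(50):
    optim.zero_grad()
    loss = ((scripted(x) - target) ** 2).mean()
    loss.backward()
    optim.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

This does not help if the update must run without Python (for example a module loaded in C++ via torch::jit::load), because the optimizer itself is never compiled.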

```python
import torch

x = torch.rand((2,))  # Dummy input data


class MyOptimModule(torch.nn.Module):
    def __init__(self, N: int, M: int):
        super().__init__()
        self.linear = torch.nn.Linear(N, M)
        # TorchScript cannot convert this torch.optim.SGD attribute to a
        # TorchScript type, so any exported method that uses it fails to compile.
        self.optim = torch.optim.SGD(self.parameters(), lr=0.1)

    def forward(self, input: torch.Tensor):
        output = self.linear(input)
        return output

    @torch.jit.export
    def manual_optim_step(self, x: torch.Tensor, num_epochs: int = 100):
        """
        Inspired by the following code:
        https://pytorch.org/tutorials/recipes/distributed_optim_torchscript.html
        """
        print(f" parsum before: \t {self.linear.weight.sum()}")
        # create some dummy loss:
        loss = self.forward(x).abs().sum()
        loss.backward()
        with torch.no_grad():
            for p in [self.linear.weight, self.linear.bias]:
                # Only update parameters that actually received a gradient.
                if p.grad is not None:
                    p.add_(-0.01 * p.grad)  # plain SGD step with lr = 0.01
                    p.grad.detach_()
                    p.grad.zero_()
        print(f" parsum after:   \t {self.linear.weight.sum()}")

    @torch.jit.export
    def optim_step(self, x: torch.Tensor):
        print(f" parsum before: \t {self.linear.weight.sum()}")
        # create some dummy loss:
        loss = self.forward(x).sum().abs()

        self.optim.zero_grad()
        loss.backward()
        print(f" gradients:   \t {self.linear.weight.grad}")

        self.optim.step()
        print(f" parsum after:   \t {self.linear.weight.sum()}")


# %%
model_with_optim = MyOptimModule(2, 3)

# Raises the RuntimeError above while optim_step is decorated with
# @torch.jit.export; with that decorator commented out, scripting succeeds.
scripted_module = torch.jit.script(model_with_optim)

scripted_module.manual_optim_step(x)
```
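
If the update really has to live inside the scripted module, the manual approach in manual_optim_step can be extended to carry optimizer state as registered buffers, which TorchScript types as plain tensors. The sketch below implements SGD with momentum (buf = momentum * buf + grad, then p = p - lr * buf, which is what torch.optim.SGD computes with dampening=0 and no Nesterov or weight decay). MomentumSGDModule, train_step, w_buf and b_buf are hypothetical names, the squared-error loss is an assumption, and, like manual_optim_step, it assumes a PyTorch version in which torch.no_grad() can be used inside scripted code.

```python
import torch


class MomentumSGDModule(torch.nn.Module):
    """Hypothetical module that carries its own SGD-with-momentum state."""

    def __init__(self, n: int, m: int, lr: float = 0.05, momentum: float = 0.9) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(n, m)
        self.lr = lr
        self.momentum = momentum
        # Momentum buffers are ordinary tensors, so TorchScript can type them.
        self.register_buffer("w_buf", torch.zeros_like(self.linear.weight))
        self.register_buffer("b_buf", torch.zeros_like(self.linear.bias))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    @torch.jit.export
    def train_step(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        loss = ((self.forward(x) - target) ** 2).mean()  # assumed loss
        loss.backward()
        with torch.no_grad():
            params = [self.linear.weight, self.linear.bias]
            bufs = [self.w_buf, self.b_buf]
            for i in range(len(params)):
                p = params[i]
                g = p.grad
                if g is not None:
                    bufs[i].mul_(self.momentum).add_(g)  # buf = momentum * buf + grad
                    p.add_(bufs[i], alpha=-self.lr)      # p = p - lr * buf
                    g.zero_()                            # reset for the next step
        return loss.detach()


torch.manual_seed(0)
scripted = torch.jit.script(MomentumSGDModule(2, 3))
x = torch.rand(2)
target = torch.zeros(3)
for _ in range(100):
    loss = scripted.train_step(x, target)
print(loss.item())
```

Because the state lives in buffers, it is saved by torch.jit.save together with the weights, so a reloaded module can continue training from where it stopped.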