Save output of torch._export.export

I am having some trouble using torch._export.export() and was wondering whether somebody could help me out.

First, is there a way to store the output (an ExportedProgram) of torch._export.export? When called, torch._export.export compiles the given function or module and then lets the user call the accelerated version. Judging by the guards it creates, the compilation appears to work similarly to plain torch.compile(); is my understanding correct?
This is very helpful, but is there a way to store the ExportedProgram so that it can be reloaded and used again, for example in a subsequent run?

My second question is related to using the ExportedProgram during training. For me, the output of the ExportedProgram has requires_grad=False and thus does not support calling backward(). Is there a way to enable the backward pass for an ExportedProgram?

My last question concerns the arguments provided to the function to be exported. If I see this correctly, the function is compiled against the example arguments provided; however, guards are introduced that specialize on the shapes of those arguments, so the exported program can only be used with inputs that have exactly the same shapes. To allow the exported function to handle multiple different shapes, is there a way to feed in a list of argument sets that the function is then compiled for?

Here is a little test program that showcases my questions:

import time

import torch
import torch.optim as optim

class RNNTest():

    def __init__(self, device) -> None:
        pass
    
    def RNNScript(self, input, param1, param2):

        # Initial hidden state; batch size (64) and hidden size (340) are hard-coded here.
        state1 = torch.zeros(64, 340, dtype=input.dtype, device=input.device)
        
        outs = []

        # Precompute the input projections for all time steps at once,
        # then split them into the two gate halves along the last dim.
        Wx = input @ param1
        Wx_inp, Wx_rec = torch.tensor_split(Wx, 2, 2)
        for wt_inp, wt_rec in zip(Wx_inp, Wx_rec):
            rec_mul_inp, rec_mul_rec = torch.tensor_split(state1 @ param2, 2, 1)
            input_prev = (wt_inp + rec_mul_inp)
            output_gate = (wt_rec + rec_mul_rec)

            state1 = 1. + input_prev * torch.sigmoid(output_gate)
            outs.append(state1)
        
        outs = torch.stack(outs)

        return outs, outs

if __name__ == "__main__":

    input_size = 140
    hidden_size = 340
    batch_size = 64
    use_gpu = True

    forward_times = []
    backward_times = []

    if use_gpu:
        device = torch.device('cuda:0')
    else:
        device = None

    rnn_test = RNNTest(device)
    
    parameters = []

    w_ih = torch.empty((input_size, hidden_size), device=device)
    w_io = torch.empty((input_size, hidden_size), device=device)
    w_i_comb = torch.cat([w_ih,w_io],1)
    w_i_comb.requires_grad_(True)
    parameters.append(w_i_comb)

    w_hh = torch.empty((hidden_size, hidden_size), device=device)
    w_ho = torch.empty((hidden_size, hidden_size), device=device)
    w_h_comb = torch.cat([w_hh,w_ho],1)
    w_h_comb.requires_grad_(True)
    parameters.append(w_h_comb)
    
    
    def count_kernels(guard):
        # Guard-failure callback (not hooked up in this script).
        print("[pt2_compile] guard failed: ", guard)
    rnnscript = torch.compile(rnn_test.RNNScript, dynamic=True, fullgraph=True)

    optimizer = optim.SGD(parameters, 0.1)

    optimizer.zero_grad()
    for execution in range(20):
        start_forward = time.time_ns()
        t_rnd = 120
        inp = torch.rand((t_rnd, batch_size, input_size))
        print(inp.shape)
        if use_gpu:
            inp = inp.cuda()
        
        # Use regular torch.compile
        out, state = rnnscript(inp, w_i_comb, w_h_comb)  # out has requires_grad=True

        # Use torch._export.export to compile the function
        val = torch._export.export(rnnscript, (inp, w_i_comb, w_h_comb))
        out_exp, st_exp = val(inp, w_i_comb, w_h_comb)
        loss = 1. - torch.sum(out_exp)
        loss.backward()  # Fails because out_exp has requires_grad=False

        inp = torch.rand((120, batch_size, input_size))
        if use_gpu:
            inp = inp.cuda()
        out_exp2, st_exp2 = val(inp, w_i_comb, w_h_comb)  # Works: the input shape matches the one used with torch._export.export

        inp = torch.rand((80, batch_size, input_size))
        if use_gpu:
            inp = inp.cuda()
        out_exp2, st_exp2 = val(inp, w_i_comb, w_h_comb)  # Fails: inp has a different shape than the example input used with torch._export.export

Thank you in advance for your help.

This test might help answer some of your questions: https://github.com/pytorch/pytorch/blob/main/test/cpp/aot_inductor/test.py

And some specific answers to your questions:

  1. export is inference focused; if you're interested in further fine-tuning a model, we recommend you stick to torch.compile() (a minimal training sketch follows this list)
  2. If you now want to load the program again, you should take a look at the AOT Inductor test I posted above (see also the save/load sketch below)
  3. You can try dynamic shapes, although AFAIK support is still limited with export (see the dynamic-shapes sketch below)
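
On point 1: torch.compile keeps autograd intact, so a regular training step works on the compiled module. A minimal sketch (the toy nn.Linear model is just for illustration):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4)

loss = compiled(x).sum()
loss.backward()  # works: outputs of the compiled module keep requires_grad
opt.step()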
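
On point 2: besides the AOT Inductor route, recent releases expose a serialization API directly on torch.export. A minimal sketch, assuming a current PyTorch where torch.export.save / torch.export.load and ExportedProgram.module() are available (the Wrapper module is made up for illustration):

import torch
import torch.nn as nn
from torch.export import export

class Wrapper(nn.Module):
    def forward(self, x):
        return torch.sin(x) + 1.

# Trace once with example inputs.
ep = export(Wrapper(), (torch.randn(4, 4),))

torch.export.save(ep, "program.pt2")          # serialize the ExportedProgram
ep_loaded = torch.export.load("program.pt2")  # reload it in a later run
print(ep_loaded.module()(torch.randn(4, 4)))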
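
On point 3: in recent releases a dimension can be marked symbolic with torch.export.Dim and the dynamic_shapes argument (in the torch._export era the same idea was expressed through constraints instead). A sketch under those assumptions, with the toy Model and the bounds made up:

import torch
import torch.nn as nn
from torch.export import Dim, export

class Model(nn.Module):
    def forward(self, x):
        return x.cumsum(dim=0)  # shape-polymorphic along dim 0

t = Dim("t", min=2, max=512)  # symbolic size for the leading (time) dimension
ep = export(Model(), (torch.randn(120, 8),), dynamic_shapes={"x": {0: t}})

# The same exported program now accepts different lengths along dim 0.
print(ep.module()(torch.randn(120, 8)).shape)
print(ep.module()(torch.randn(80, 8)).shape)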

Thank you very much for your response and this information! I will have a look at the test case.