I am having some trouble using
torch._export.export() and was wondering whether somebody could help me out.
First, is there a way to store the output (an ExportedProgram) of torch._export.export? When called, torch._export.export compiles the given function or module and then lets the user run the accelerated version. Judging by the guards it creates, the compilation seems to work similarly to plain torch.compile(); is my understanding correct?
This is very helpful, but is there a way to store this ExportedProgram such that it can be reloaded and used again, for example in a subsequent run?
My second question is related to using the ExportedProgram during training. For me, the output of the ExportedProgram has
requires_grad=False and thus does not support the backward function. Is there a way to enable the backward pass for an ExportedProgram?
My last question concerns the arguments provided to the function being exported. If I see this correctly, the function is compiled with the example arguments provided, but guards are introduced that specialize on the shapes of those arguments. As a result, the exported function can only be used with inputs that have exactly the same shapes. To allow the exported function to handle multiple different shapes, is there a way to feed a list of arguments that the function is subsequently compiled for?
Here is a little test program that showcases my questions:
```python
from typing import List, Tuple, Optional, overload, Union, cast
import torch
import numpy as np
import time
import torch.optim as optim
from torch.nn.parameter import Parameter


class RNNTest():
    def __init__(self, device) -> None:
        pass

    def RNNScript(self, input, param1, param2):
        state1 = torch.zeros(64, 340, dtype=input.dtype, device=input.device)
        outs = []
        Wx = input @ param1
        Wx_inp, Wx_rec = torch.tensor_split(Wx, 2, 2)
        for wt_inp, wt_rec in zip(Wx_inp, Wx_rec):
            rec_mul_inp, rec_mul_rec = torch.tensor_split(state1 @ param2, 2, 1)
            input_prev = (wt_inp + rec_mul_inp)
            output_gate = (wt_rec + rec_mul_rec)
            state1 = 1. + input_prev * torch.sigmoid(output_gate)
            outs.append(state1)
        outs = torch.stack(outs)
        return outs, (outs)


if __name__ == "__main__":
    input_size = 140
    hidden_size = 340
    batch_size = 64
    use_gpu = True
    forward_times = []
    backward_times = []

    if use_gpu:
        device = torch.device('cuda:0')
    else:
        device = None

    rnn_test = RNNTest(device)

    parameters = []
    w_ih = torch.empty((input_size, hidden_size), device=device)
    w_io = torch.empty((input_size, hidden_size), device=device)
    w_i_comb = torch.cat([w_ih, w_io], 1)
    w_i_comb.requires_grad_(True)
    parameters.append(w_i_comb)
    w_hh = torch.empty((hidden_size, hidden_size), device=device)
    w_ho = torch.empty((hidden_size, hidden_size), device=device)
    w_h_comb = torch.cat([w_hh, w_ho], 1)
    w_h_comb.requires_grad_(True)
    parameters.append(w_h_comb)

    def count_kernels(guard):
        print("[pt2_compile] guard failed: ", guard)

    rnnscript = torch.compile(rnn_test.RNNScript, dynamic=True, fullgraph=True)

    optimizer = optim.SGD(parameters, 0.1)
    optimizer.zero_grad()

    for execution in range(20):
        start_forward = time.time_ns()
        t_rnd = 120
        inp = torch.rand((t_rnd, batch_size, input_size))
        print(inp.shape)
        if use_gpu:
            inp = inp.cuda()

        # Use regular torch.compile
        out, state = rnnscript(inp, w_i_comb, w_h_comb)  # out has requires_grad=True

        # Use torch._export.export to compile the function
        val = torch._export.export(rnnscript, [inp, w_i_comb, w_h_comb])
        out_exp, st_exp = val(inp, w_i_comb, w_h_comb)
        loss = 1. - torch.sum(out_exp)
        loss.backward()  # Fails because out_exp has requires_grad=False

        inp = torch.rand((120, batch_size, input_size))
        if use_gpu:
            inp = inp.cuda()
        # Works, as the input shapes are the same as the ones used with torch._export.export
        out_exp2, st_exp2 = val(inp, w_i_comb, w_h_comb)

        inp = torch.rand((80, batch_size, input_size))
        if use_gpu:
            inp = inp.cuda()
        # Fails because inp has a different shape than the original inp used with torch._export.export
        out_exp2, st_exp2 = val(inp, w_i_comb, w_h_comb)
```
Thank you in advance for your help.