I am having some trouble using torch._export.export() and was wondering whether somebody could help me out.
First, is there a way to store the output (an ExportedProgram) of torch._export.export()? As I understand it, torch._export.export() compiles the given function or module and returns an accelerated version that can be called directly. Judging by the guards it creates, the compilation seems to work much like torch.compile(). Is my understanding correct?
This is very helpful, but I would like to save the ExportedProgram so that it can be reloaded and reused, for example in a subsequent run. Is that possible?
My second question concerns using the ExportedProgram during training. For me, the output of the ExportedProgram has requires_grad=False and therefore does not support backward(). Is there a way to enable the backward pass for the ExportedProgram?
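For comparison, this is the behaviour I get from torch.compile() and would like from the exported program as well. It is a reduced sketch rather than my real code: step is a hypothetical one-step function, and I assume the "aot_eager" backend here only so that the snippet runs without a C++ toolchain.

```python
import torch


def step(x, w):
    # As in my RNN function, the parameter is passed in as a plain argument.
    return 1.0 + x * torch.sigmoid(x @ w)


compiled_step = torch.compile(step, backend="aot_eager")

w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(4, 3)

out = compiled_step(x, w)

# The compiled output stays attached to the autograd graph,
# so the loss can be backpropagated to w.
loss = 1.0 - out.sum()
loss.backward()
```

Here out.requires_grad is True and w.grad is populated after backward(), which is what I cannot reproduce with the ExportedProgram.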
My last question concerns the arguments provided to the function to be exported. If I see this correctly, the function is compiled with the arguments provided, and guards are introduced that are specialized to the shapes of these arguments. As a result, the exported function can only be used with inputs that have exactly the same shapes. To let it handle several different shapes, is there a way to pass a list of example arguments that the function is then compiled for?
Here is a little test program that showcases my questions:
from typing import List, Tuple, Optional, overload, Union, cast
import torch
import numpy as np
import time
import torch.optim as optim
from torch.nn.parameter import Parameter
class RNNTest():
    def __init__(self, device) -> None:
        pass

    def RNNScript(self,
                  input,
                  param1,
                  param2,
                  ):
        state1 = torch.zeros(64, 340, dtype=input.dtype, device=input.device)
        outs = []
        Wx = input @ param1
        Wx_inp, Wx_rec = torch.tensor_split(Wx, 2, 2)
        for wt_inp, wt_rec in zip(Wx_inp, Wx_rec):
            rec_mul_inp, rec_mul_rec = torch.tensor_split(state1 @ param2, 2, 1)
            input_prev = (wt_inp + rec_mul_inp)
            output_gate = (wt_rec + rec_mul_rec)
            state1 = 1. + input_prev * torch.sigmoid(output_gate)
            outs.append(state1)
        outs = torch.stack(outs)
        return outs, (outs)

if __name__ == "__main__":
    input_size = 140
    hidden_size = 340
    batch_size = 64
    use_gpu = True
    forward_times = []
    backward_times = []
    if use_gpu:
        device = torch.device('cuda:0')
    else:
        device = None
    rnn_test = RNNTest(device)
    parameters = []
    w_ih = torch.empty((input_size, hidden_size), device=device)
    w_io = torch.empty((input_size, hidden_size), device=device)
    w_i_comb = torch.cat([w_ih, w_io], 1)
    w_i_comb.requires_grad_(True)
    parameters.append(w_i_comb)
    w_hh = torch.empty((hidden_size, hidden_size), device=device)
    w_ho = torch.empty((hidden_size, hidden_size), device=device)
    w_h_comb = torch.cat([w_hh, w_ho], 1)
    w_h_comb.requires_grad_(True)
    parameters.append(w_h_comb)

    def count_kernels(guard):
        print("[pt2_compile] guard failed: ", guard)

    rnnscript = torch.compile(rnn_test.RNNScript, dynamic=True, fullgraph=True)
    optimizer = optim.SGD(parameters, 0.1)
    optimizer.zero_grad()
    for execution in range(20):
        start_forward = time.time_ns()
        t_rnd = 120
        inp = torch.rand((t_rnd, batch_size, input_size))
        print(inp.shape)
        if use_gpu:
            inp = inp.cuda()
        # Use regular torch.compile
        out, state = rnnscript(inp, w_i_comb, w_h_comb)  # out has requires_grad=True
        # Use torch._export.export to compile the function
        val = torch._export.export(rnnscript, [inp, w_i_comb, w_h_comb])
        out_exp, st_exp = val(inp, w_i_comb, w_h_comb)
        loss = 1. - torch.sum(out_exp)
        loss.backward()  # Fails because out_exp has requires_grad=False
        inp = torch.rand((120, batch_size, input_size))
        if use_gpu:
            inp = inp.cuda()
        # Works because the input shapes match the ones passed to torch._export.export
        out_exp2, st_exp2 = val(inp, w_i_comb, w_h_comb)
        inp = torch.rand((80, batch_size, input_size))
        if use_gpu:
            inp = inp.cuda()
        # Fails because inp has a different shape than the inp passed to torch._export.export
        out_exp2, st_exp2 = val(inp, w_i_comb, w_h_comb)

Thank you in advance for your help.