Is it possible to convert PyTorch multimodal CVAE to TensorRT?

Description

Hi, I’m working on behavioral prediction code that uses a multimodal CVAE (link: GitHub - StanfordASL/Trajectron-plus-plus: Code accompanying the ECCV 2020 paper “Trajectron++: Dynamically-Feasible Trajectory Forecasting With Heterogeneous Data” by Tim Salzmann*, Boris Ivanovic*, Punarjay Chakravarty, and Marco Pavone (* denotes equal contribution)). In the examples I have seen, the common approach is PyTorch to ONNX, then ONNX to TensorRT. But since this model has an encode-decode structure and is multimodal, I could not figure out how to use the torch.onnx.export() function for it.
PyTorch model link: Trajectron-plus-plus/experiments/nuScenes/models/int_ee at eccv2020 · StanfordASL/Trajectron-plus-plus · GitHub
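For the second step (ONNX to TensorRT) I am planning to use roughly the snippet below, based on the TensorRT 8.x Python API (model.onnx / model.plan are placeholder file names), so the part I am actually stuck on is producing a valid ONNX file in the first place:

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

with open("model.onnx", "rb") as f:          # placeholder ONNX file name
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))

config = builder.create_builder_config()
serialized_engine = builder.build_serialized_network(network, config)   # TensorRT 8.x API
with open("model.plan", "wb") as f:
    f.write(serialized_engine)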

The torch.onnx.export() function accepts these arguments:

torch.onnx.export(torch_model,              # model being run
                  x,                        # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",  # where to save the model (can be a file or file-like object)
                  export_params=True,       # store the trained parameter weights inside the model file
                  opset_version=10,         # the ONNX version to export the model to
                  do_constant_folding=True, # whether to execute constant folding for optimization
                  input_names=['input'],    # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},     # variable length axes
                                'output': {0: 'batch_size'}})

Given the information below, what should input_names, output_names, and dynamic_axes be?

As shown below, my inputs change according to the node classification (VEHICLE or PEDESTRIAN).

The model's encode function is this:
self.node_models_dict[node].encoder_forward(inputs,
                                             inputs_st,
                                             inputs_np,
                                             robot_present_and_future,
                                             maps)
Example encode inputs look like this:
inputs = {VEHICLE/35521: tensor([[-2.5944e+01,  5.7025e+00, -5.8111e+00,  1.1085e-03, -8.1415e-01,
          4.1824e-02,  1.8372e+02,  7.1199e-02]], device='cuda:0')}
inputs_st = {VEHICLE/35521: tensor([[ 0.0000e+00,  0.0000e+00, -3.8740e-01,  7.3899e-05, -2.0354e-01,
             1.0456e-02,  5.8481e+01,  7.1199e-02]], device='cuda:0')}
inputs_np = {VEHICLE/35521: array([[-2.59436354e+01,  5.70252158e+00, -5.81107431e+00,
             1.10847897e-03, -8.14148775e-01,  4.18236967e-02,
             1.83724350e+02,  7.11987019e-02]])}
robot_present_and_future: None
maps: None
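Since encoder_forward takes dicts keyed by the node plus a NumPy array rather than plain tensors, my current idea is to wrap it in a small nn.Module that bakes in the node key and the arguments that are None in my use case, so that torch.onnx.export only sees tensor inputs. This is just a sketch of what I mean (the class name and structure are my own, not from the repo), and I suspect the NumPy conversion inside forward would be frozen as a constant during tracing, which may be part of the problem:

import torch
import torch.nn as nn

class EncoderExportWrapper(nn.Module):
    # Hypothetical wrapper: fixes the node key and the None arguments so the
    # traced graph only has tensor inputs.
    def __init__(self, node_model, node_key):
        super().__init__()
        self.node_model = node_model      # e.g. self.node_models_dict[node]
        self.node_key = node_key          # e.g. "VEHICLE/35521"

    def forward(self, inputs, inputs_st):
        inputs_dict = {self.node_key: inputs}
        inputs_st_dict = {self.node_key: inputs_st}
        # NOTE: not traceable; this would be baked into the ONNX graph as a constant.
        inputs_np = {self.node_key: inputs.detach().cpu().numpy()}
        return self.node_model.encoder_forward(inputs_dict,
                                               inputs_st_dict,
                                               inputs_np,
                                               None,   # robot_present_and_future
                                               None)   # maps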

The model's decode function is this:
model.decoder_forward(num_predicted_timesteps,
                      num_samples,
                      robot_present_and_future=robot_present_and_future,
                      z_mode=z_mode,
                      gmm_mode=gmm_mode,
                      full_dist=full_dist,
                      all_z_sep=all_z_sep)

Example decode inputs look like this:
node(encoder): VEHICLE/35521
node: VEHICLE/35521
num_predicted_timesteps: 6
num_samples: 1
robot_present_and_future: None
z_mode: False
gmm_mode: False
full_dist: True
all_z_sep: False
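For the decoder, none of the arguments above are tensors, so I assume they would all have to be fixed at export time and the tensor state produced by the encoder stage would have to become the graph input instead. Something like the sketch below, where latent_state is only a placeholder for whatever internal state decoder_forward actually reads (I have not verified the attribute names in the repo):

import torch.nn as nn

class DecoderExportWrapper(nn.Module):
    # Hypothetical wrapper: the scalar/boolean arguments of decoder_forward are
    # baked in; only the tensor state coming from the encoder stage is an input.
    def __init__(self, node_model, num_predicted_timesteps=6, num_samples=1):
        super().__init__()
        self.node_model = node_model
        self.num_predicted_timesteps = num_predicted_timesteps
        self.num_samples = num_samples

    def forward(self, latent_state):
        # Placeholder: feed the encoder output back into the node model before
        # decoding (attribute name is made up for illustration).
        self.node_model.latent_state = latent_state
        return self.node_model.decoder_forward(self.num_predicted_timesteps,
                                               self.num_samples,
                                               robot_present_and_future=None,
                                               z_mode=False,
                                               gmm_mode=False,
                                               full_dist=True,
                                               all_z_sep=False)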

The PyTorch model_dict is shown below:
name: VEHICLE/node_history_encoder model: LSTM(8, 32, batch_first=True)
name: VEHICLE/node_future_encoder model: LSTM(2, 32, batch_first=True, bidirectional=True)
name: VEHICLE/node_future_encoder/initial_h model: Linear(in_features=8, out_features=32, bias=True)
name: VEHICLE/node_future_encoder/initial_c model: Linear(in_features=8, out_features=32, bias=True)
name: VEHICLE/edge_influence_encoder model: AdditiveAttention(
(w1): Linear(in_features=32, out_features=32, bias=False)
(w2): Linear(in_features=32, out_features=32, bias=False)
(v): Linear(in_features=32, out_features=1, bias=False)
)
name: VEHICLE/p_z_x model: Linear(in_features=64, out_features=32, bias=True)
name: VEHICLE/hx_to_z model: Linear(in_features=32, out_features=25, bias=True)
name: VEHICLE/hxy_to_z model: Linear(in_features=192, out_features=25, bias=True)
name: VEHICLE/decoder/state_action model: Sequential(
(0): Linear(in_features=8, out_features=2, bias=True)
)
name: VEHICLE/decoder/rnn_cell model: GRUCell(91, 128)
name: VEHICLE/decoder/initial_h model: Linear(in_features=89, out_features=128, bias=True)
name: VEHICLE/decoder/proj_to_GMM_log_pis model: Linear(in_features=128, out_features=1, bias=True)
name: VEHICLE/decoder/proj_to_GMM_mus model: Linear(in_features=128, out_features=2, bias=True)
name: VEHICLE/decoder/proj_to_GMM_log_sigmas model: Linear(in_features=128, out_features=2, bias=True)
name: VEHICLE/decoder/proj_to_GMM_corrs model: Linear(in_features=128, out_features=1, bias=True)
name: VEHICLE->VEHICLE/edge_encoder model: LSTM(16, 32, batch_first=True)
name: VEHICLE->PEDESTRIAN/edge_encoder model: LSTM(14, 32, batch_first=True)
name: VEHICLE/unicycle_initializer model: Linear(in_features=65, out_features=1, bias=True)
name: PEDESTRIAN/node_history_encoder model: LSTM(6, 32, batch_first=True)
name: PEDESTRIAN/node_future_encoder model: LSTM(2, 32, batch_first=True, bidirectional=True)
name: PEDESTRIAN/node_future_encoder/initial_h model: Linear(in_features=6, out_features=32, bias=True)
name: PEDESTRIAN/node_future_encoder/initial_c model: Linear(in_features=6, out_features=32, bias=True)
name: PEDESTRIAN/edge_influence_encoder model: AdditiveAttention(
(w1): Linear(in_features=32, out_features=32, bias=False)
(w2): Linear(in_features=32, out_features=32, bias=False)
(v): Linear(in_features=32, out_features=1, bias=False)
)
name: PEDESTRIAN/p_z_x model: Linear(in_features=64, out_features=32, bias=True)
name: PEDESTRIAN/hx_to_z model: Linear(in_features=32, out_features=25, bias=True)
name: PEDESTRIAN/hxy_to_z model: Linear(in_features=192, out_features=25, bias=True)
name: PEDESTRIAN/decoder/state_action model: Sequential(
(0): Linear(in_features=6, out_features=2, bias=True)
)
name: PEDESTRIAN/decoder/rnn_cell model: GRUCell(91, 128)
name: PEDESTRIAN/decoder/initial_h model: Linear(in_features=89, out_features=128, bias=True)
name: PEDESTRIAN/decoder/proj_to_GMM_log_pis model: Linear(in_features=128, out_features=1, bias=True)
name: PEDESTRIAN/decoder/proj_to_GMM_mus model: Linear(in_features=128, out_features=2, bias=True)
name: PEDESTRIAN/decoder/proj_to_GMM_log_sigmas model: Linear(in_features=128, out_features=2, bias=True)
name: PEDESTRIAN/decoder/proj_to_GMM_corrs model: Linear(in_features=128, out_features=1, bias=True)
name: PEDESTRIAN->VEHICLE/edge_encoder model: LSTM(14, 32, batch_first=True)
name: PEDESTRIAN->PEDESTRIAN/edge_encoder model: LSTM(12, 32, batch_first=True)
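Because the VEHICLE and PEDESTRIAN branches have different input widths (8 vs 6 history features), my guess is that each node type would have to be exported as a separate ONNX file, roughly as below (again only a guess at input_names / output_names / dynamic_axes; indexing node_models_dict by node type is my assumption, since in my runs the keys are node instances like VEHICLE/35521):

import torch

# Rough per-node-type export, reusing the EncoderExportWrapper sketch above.
node_type_to_width = {"VEHICLE": 8, "PEDESTRIAN": 6}

for node_type, width in node_type_to_width.items():
    wrapper = EncoderExportWrapper(model.node_models_dict[node_type], node_type).eval().cuda()
    dummy = torch.randn(1, width, device="cuda")
    torch.onnx.export(wrapper,
                      (dummy, dummy),
                      f"encoder_{node_type.lower()}.onnx",
                      export_params=True,
                      opset_version=11,
                      input_names=["inputs", "inputs_st"],
                      output_names=["encoded"],
                      dynamic_axes={"inputs": {0: "batch_size"},
                                    "inputs_st": {0: "batch_size"}})

Is splitting the export per node type (and per encoder/decoder) the right direction for getting this model onto TensorRT, or is there a better approach for conditional, dict-based models like this?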