Slow libtorch/tch-rs - Output Archive trace?

Hello,

I’m moving some Python code into Rust using tch-rs. Performance is better than calling back into Python from my Rust code via PyO3 bindings, but only barely, which seems wrong.

When I take a trace, my thread spends almost all of its time in torch::serialize::OutputArchive::save_to, and I’m not sure why, since I’m not trying to output or save anything. I’ll try to attach an example call stack.

Below is the struct that holds the network, followed by the relevant part of new and the method containing the forward pass.

use tch::{nn, nn::Module, Device, Tensor};

pub struct GymWrapper {
    gym: Gym,
    net: Box<dyn nn::Module>,
}

impl GymWrapper {
    pub fn new(/* gym setup arguments omitted */) -> Self {
        // other code to set up the gym removed...
        let var_store = nn::VarStore::new(Device::Cpu);
        let net = network::net(&var_store.root());
        GymWrapper { gym, net: Box::new(net) }
    }

    pub fn step_episode(&mut self, seed: Option<u64>) -> bool {
        // Run libtorch ops on a single intra-op thread.
        tch::set_num_threads(1);

        let mut steps = 0;
        let mut done = false;
        // Run 10 episodes back to back.
        for _i in 0..10 {
            done = false;
            let mut obs = self.gym.reset(Some(false), seed);

            while !done {
                // Observations -> 2-D f32 tensor.
                let tens_obs = Tensor::from_slice2(&obs);
                // Inference only: no autograd graph is recorded.
                let actions: Tensor = tch::no_grad(|| self.net.forward(&tens_obs));
                let act_vec: Vec<Vec<f32>> = actions.try_into().expect("error from tensor to vector");
                let result = self.gym.step(act_vec);
                obs = result.0;
                done = result.2;
                steps += 1;
            }
        }
        // ... remainder of step_episode (including its return value) not shown
    }
}

My test sets up the GymWrapper and then times a call to step_episode. I captured the attached trace while step_episode was running. Roughly 30% of the exclusive time is attributed to torch::serialize::OutputArchive::save_to.
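To see whether the time goes to the network or to the gym, one option is to time each stage on its own rather than the whole of step_episode. A minimal sketch using only std::time::Instant; the helper name time_per_call is hypothetical, and the closure in main is a stand-in workload where the real code would wrap, for example, the no_grad forward call or the gym step.

```rust
use std::time::{Duration, Instant};

/// Hypothetical helper: runs `f` `iters` times and returns the mean
/// wall-clock time per call.
fn time_per_call<F: FnMut()>(iters: u32, mut f: F) -> Duration {
    assert!(iters > 0, "iters must be positive");
    let start = Instant::now();
    for _ in 0..iters {
        f();
    }
    start.elapsed() / iters
}

fn main() {
    // Stand-in workload (assumption): in the real code this closure would wrap
    // one stage only, e.g. the forward pass or the gym step, so the two can be compared.
    let data: Vec<f32> = (0..231).map(|i| i as f32).collect();
    let mut sink = 0.0f32;
    let per_call = time_per_call(1_000, || {
        sink += data.iter().sum::<f32>();
    });
    println!("mean per call: {:?} (sink = {})", per_call, sink);
}
```

Timing the forward pass and the gym step separately this way gives a baseline to compare against what the profiler attributes to each symbol.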

Any ideas what I’m doing wrong, or what save_to is trying to do here?

The network definition is below, since it’s long and ugly.

use tch::nn::{self, LinearConfig, Module};
use tch::nn::init::{FanInOut, Init, NonLinearity, NormalOrUniform};

const HIDDEN_NODES: i64 = 256;
const INPUT_DIM: i64 = 231;
const OUTPUT_DIM: i64 = 90;

// Kaiming-normal init (fan-in, ReLU gain), used for weights and biases alike.
fn kaiming() -> Init {
    Init::Kaiming {
        dist: NormalOrUniform::Normal,
        fan: FanInOut::FanIn,
        non_linearity: NonLinearity::ReLU,
    }
}

// Same LinearConfig for every layer: Kaiming weights, Kaiming bias, bias enabled.
fn kaiming_cfg() -> LinearConfig {
    LinearConfig { ws_init: kaiming(), bs_init: Some(kaiming()), bias: true }
}

// 231 -> 256 -> 256 -> 256 -> 90 MLP with leaky ReLU between layers.
pub fn net(my_net: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(my_net / "layer1", INPUT_DIM, HIDDEN_NODES, kaiming_cfg()))
        .add_fn(|xs| xs.leaky_relu())
        .add(nn::linear(my_net / "layer2", HIDDEN_NODES, HIDDEN_NODES, kaiming_cfg()))
        .add_fn(|xs| xs.leaky_relu())
        .add(nn::linear(my_net / "layer3", HIDDEN_NODES, HIDDEN_NODES, kaiming_cfg()))
        .add_fn(|xs| xs.leaky_relu())
        .add(nn::linear(my_net / "layer4", HIDDEN_NODES, OUTPUT_DIM, kaiming_cfg()))
}
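For a sense of scale, the size of this four-layer MLP follows directly from the layer widths. A small sketch; the helper names mlp_param_count and mlp_macs_per_row are hypothetical, and each linear layer is counted as in*out weights plus out biases (all four layers use bias: true).

```rust
const HIDDEN_NODES: i64 = 256;
const INPUT_DIM: i64 = 231;
const OUTPUT_DIM: i64 = 90;

/// Hypothetical helper: trainable parameters of a stack of biased linear layers
/// whose widths are given in order (input, hidden..., output).
fn mlp_param_count(sizes: &[i64]) -> i64 {
    sizes.windows(2).map(|w| w[0] * w[1] + w[1]).sum()
}

/// Hypothetical helper: multiply-accumulates needed per input row
/// (weight matmuls only; biases and activations ignored).
fn mlp_macs_per_row(sizes: &[i64]) -> i64 {
    sizes.windows(2).map(|w| w[0] * w[1]).sum()
}

fn main() {
    let sizes = [INPUT_DIM, HIDDEN_NODES, HIDDEN_NODES, HIDDEN_NODES, OUTPUT_DIM];
    println!("parameters: {}", mlp_param_count(&sizes)); // prints "parameters: 214106"
    println!("MACs per row: {}", mlp_macs_per_row(&sizes)); // prints "MACs per row: 213248"
}
```

These numbers give a rough baseline for judging how much of the trace the forward pass itself should account for.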

[Trace screenshot]