Model.forward() with the same input size as in PyTorch leads to a dimension error in libtorch

Hi!
I have this code, a libtorch C++ inference class for models trained on various platforms. I load my .pth model and try to use it to predict output data, but I quickly encounter an error:

tensor sent to cuda
tensor sent to half
leaving toTensor
input_.at(0).size() = [1, 3, 384, 288]
entering inference
terminate called after throwing an instance of 'std::runtime_error'
  what():  The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/models/yolo.py", line 46, in forward
    _35 = (_4).forward(_34, )
    _36 = (_2).forward((_3).forward(_35, ), _29, )
    _37 = (_0).forward(_33, _35, (_1).forward(_36, ), )
           ~~~~~~~~~~~ <--- HERE
    _38, _39, _40, _41, = _37
    return (_41, [_38, _39, _40])
  File "code/__torch__/models/yolo.py", line 75, in forward
    y = torch.sigmoid(_50)
    _51 = torch.mul(torch.slice(y, 4, 0, 2), CONSTANTS.c0)
    _52 = torch.add(torch.sub(_51, CONSTANTS.c1), CONSTANTS.c2)
          ~~~~~~~~~ <--- HERE
    xy = torch.mul(_52, torch.select(CONSTANTS.c3, 0, 0))
    _53 = torch.mul(torch.slice(y, 4, 2, 4), CONSTANTS.c4)

Traceback of TorchScript, original code (most recent call last):
/home/kubler/data/det_track/yolov5_bis/models/yolo.py(66): forward
/home/kubler/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/home/kubler/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/home/kubler/data/det_track/yolov5_bis/models/yolo.py(143): _forward_once
/home/kubler/data/det_track/yolov5_bis/models/yolo.py(121): forward
/home/kubler/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/home/kubler/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/home/kubler/anaconda3/lib/python3.8/site-packages/torch/jit/_trace.py(952): trace_module
/home/kubler/anaconda3/lib/python3.8/site-packages/torch/jit/_trace.py(735): trace
export.py(54): export_torchscript
export.py(279): run
/home/kubler/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py(28): decorate_context
export.py(328): main
export.py(333): <module>
RuntimeError: The size of tensor a (36) must match the size of tensor b (80) at non-singleton dimension 3

The input tensor size in the original repo is img.size = torch.Size([1, 3, 384, 288]), and I format my tensor to match that type and size.

I had already encountered this error with YOLO before: it was caused by input dimensions that didn't match, and I needed to make the input square. At the time I solved it by padding the image to make it square, but here that didn't work.
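For reference, by "padding to make it square" I mean the usual letterbox-style preprocessing. Here is a minimal Python/OpenCV sketch of the idea (the function name and defaults are illustrative, not my actual resize/padToSize helpers):

import cv2

def letterbox(img, dst_w=288, dst_h=384, pad_value=114):
    # scale the image to fit inside dst_w x dst_h while keeping the aspect ratio
    scale = min(dst_w / img.shape[1], dst_h / img.shape[0])
    new_w = int(round(img.shape[1] * scale))
    new_h = int(round(img.shape[0] * scale))
    resized = cv2.resize(img, (new_w, new_h))
    # pad the remainder with a constant border so the output is exactly dst_h x dst_w
    top = (dst_h - new_h) // 2
    bottom = dst_h - new_h - top
    left = (dst_w - new_w) // 2
    right = dst_w - new_w - left
    padded = cv2.copyMakeBorder(resized, top, bottom, left, right,
                                cv2.BORDER_CONSTANT, value=(pad_value,) * 3)
    return padded, scale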

Here is my code:

//______________________RUN_________________________
void PoseDetector::run(bool display = false){
  if (camera_.isOpened())
    dataset_ = getCameraFrame();
  else if (! ( getDataPath().empty() ) )
    dataset_ = loadDataset();

  std::cout<< " dataset_.at(0).size() = " << dataset_.at(0).size() <<std::endl;
  input_ = prepocData( dataset_ ,  inputSize_ );
  std::cout<< "input_.at(0).size() = " << input_.at(0).toTensor().sizes() <<std::endl;
  detections_tens_ = inference(input_);
/*
  detections_ = postProcessing(detections_tens_);
  if ( display ){
    demo(  dataset_.at(0) , detections_ , labels_ );
  }
*/
}


//______________________PREPROC_DATA___________________

std::vector<torch::jit::IValue> PoseDetector::prepocData(std::vector<cv::Mat>& matrix_data , cv::Size dst_sz){
  //std::cout << "entering preproc data"<<std::endl;

  cv::Mat normalized, resized, padded;
  std::vector<torch::jit::IValue> res ;

  for( auto it = std::begin(matrix_data) ; it != std::end(matrix_data) ; it ++  ){

    //________detect_preprocess
    (*it).convertTo( normalized,  CV_32FC3, 1.0 / 255, 0); // normalize to [0, 1] and convert to float, matching the tensor input type as in ultralytics yolov5 detect.py
    //std::cout << "normalized shape =  " << normalized.cols << " , " << normalized.rows << " , " << normalized.channels() << std::endl ;

    //___________resizing&padding to square the image
    //resize bigger dim to match yolo training input size.
    scale_ = resize( normalized , resized , dst_sz);
    //std::cout << "resized shape =  " << resized.cols << " , " << resized.rows << " , " << resized.channels() << std::endl ;

    padToSize(resized , padded , dst_sz);
    //std::cout << "after pad : padded shape =  " << padded.cols << " , " << padded.rows << " , " << padded.channels() << std::endl ;

    torch::Tensor tens ;
    toTensor(padded , tens);
    //std::cout << "tens.sizes() = " << tens.sizes() << std::endl ;
    //std::cout << "tens.device() = " << tens.device() << std::endl ;

    res.emplace_back(tens);
  }
  //std::cout << "leaving preproc data"<<std::endl;

  return res;
}

//______________________INFERENCE_______________________
torch::Tensor PoseDetector::inference(std::vector<torch::jit::IValue>& input) {
  // runs model inference on the loaded model
  std::cout << "entering inference" << std::endl;
  torch::NoGradGuard no_grad; // inference only, no autograd graph needed
  torch::jit::IValue output = module_.forward(input);
  std::cout << "forward" << std::endl;
  // the model returns a tuple; the detections are its first element
  torch::Tensor detections = output.toTuple()->elements()[0].toTensor();
  std::cout << "extracting detections" << std::endl;
  std::cout << "leaving inference" << std::endl;
  return detections;
}

Model loading:

torch::jit::script::Module Detector::loadDetector(const std::string& path, torch::Device device) {
  // load a TorchScript module (the serialized model) from the given path
  try {
    // Deserialize the TorchScript module directly onto the target device.
    module_ = torch::jit::load(path, device);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model located in: " << path << std::endl;
    std::exit(EXIT_FAILURE);
  }
  // on GPU the module runs in half precision, so input tensors must be
  // converted to kHalf as well before calling forward()
  if (device != torch::kCPU) {
    module_.to(torch::kHalf);
  }
  module_.eval();
  return module_;
}

Thanks for your help!

As you’ve described, the issue is caused by a shape mismatch in a spatial dimension. The mismatched sizes look like feature-map widths (288 / 8 = 36 vs. 640 / 8 = 80), so my guess is that the model was traced with a 640x640 input and the grid tensors were baked into the graph as constants.
Did you trace or script the model before exporting it? The former would not capture any data-dependent control flow (or shapes), and I would assume YOLO needs that.
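For example, tracing records the ops executed for the given example input and bakes data-dependent Python values (shapes, branches) into the graph as constants, while scripting compiles the control flow itself. A toy example using the standard torch.jit APIs:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x):
        # data-dependent control flow: tracing freezes the branch taken
        # for the example input, scripting keeps the `if`
        if x.shape[-1] > 50:
            return x * 2
        return x

model = Toy()
traced = torch.jit.trace(model, torch.randn(1, 3, 80, 80))
scripted = torch.jit.script(model)

x_small = torch.randn(1, 3, 40, 40)
print(traced(x_small).eq(x_small * 2).all())  # True: the `> 50` branch was baked in
print(scripted(x_small).eq(x_small).all())    # True: the `if` is re-evaluated

You could also load the exported TorchScript file back in Python and call it with a [1, 3, 384, 288] input; if it fails there in the same way, the problem comes from the export itself, not from your libtorch preprocessing.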

Well, I have tried, but the program won't allow me to do so. I'm working with mmpose, and the complexity of the repo doesn't allow that.
I either get a non-traced .pt or a .onnx. Would you happen to know how to open a .onnx in libtorch, OR how to simply convert this ONNX model to PyTorch?

No, unfortunately I don’t know the current state of support between libtorch and ONNX (or of potential tools to convert between them).

Thanks for your help @ptrblck
I have finally found a way to do so.
As you said, my model was indeed not traced, and this is what led to the error.

I used the onnx2pytorch repo to transform my ONNX module into a traced PyTorch module, with the following unfinished-but-you-get-the-idea script that converts the ONNX model to .pth and traces it.
(Note: this script expects a "torch" folder in the current directory to work. My working directory looked like this:

  • onnx2torch.py
  • torch/ (where the output PyTorch models are stored)
  • onnx/ (where the input ONNX models are stored)
)
#onnx2torch.py
import onnx
from onnx2pytorch import ConvertModel
import sys, os
import torch


def main(argv):
    path_to_onnx = argv[1]
    size = [1, 3, 384, 288]  # TODO: parse from argv[2]
    print(size)

    # derive torch/<model_name>.pth from onnx/<model_name>.onnx
    stem = os.path.splitext(os.path.basename(path_to_onnx))[0]
    path_to_torch = os.path.join("torch", stem + ".pth")
    print(f"load path = {path_to_onnx}")
    print(f"save_path = {path_to_torch}")

    onnx_model = onnx.load(path_to_onnx)
    pytorch_model = ConvertModel(onnx_model)
    pytorch_model.eval()  # eval mode so tracing does not record train-time behavior

    # An example input you would normally provide to your model's forward() method.
    dummy_input = torch.ones(size)

    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(pytorch_model, dummy_input)
    traced_script_module.save(path_to_torch)
    print("saved")


if __name__ == "__main__":
    main(sys.argv)
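
Assuming the folder layout above, the script is run as e.g. python onnx2torch.py onnx/my_model.onnx (the model name here is just a placeholder). The resulting traced module can then be sanity-checked from Python before loading it with torch::jit::load in libtorch:

import torch

m = torch.jit.load("torch/my_model.pth")  # placeholder path
out = m(torch.ones(1, 3, 384, 288))       # same size the model was traced with
print(type(out))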