Calling torch::jit::Module::forward() multiple times crashes in assert(initialized())

Hi all,

I am trying to run the optical flow estimator RAFT (more specifically, the raft-small.pth version from GitHub - princeton-vl/RAFT, downloadable there under “Demos”) as a scripted model in LibTorch with CUDA on Windows, and I get a strange error. Running the model once works perfectly (in some contexts even twice), but a later forward(…) call crashes with:
Assertion failed: initialized(), file C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\c10/util/Optional.h, line 763

I am using Libtorch 1.13.1.
This looks somewhat similar to “Second forward call of torchscripted module breaks on cuda”. However, there the issue does not seem to be in the initialized() call but in some JIT fusers(?), which I don’t know much about. That report is also in Python, whereas mine is in C++ (on my side, Python inference works!).

The scripted model is created in Python as follows:

import torch
from argparse import Namespace
from RAFT.core.raft import RAFT
import torch.nn.functional as F
import torchvision.transforms.functional as fn


def preprocess(image1: torch.Tensor, image2: torch.Tensor):
    """Turn two channel-last frames into batched float tensors for the flow model.

    Each frame is permuted with (2, 0, 1), cast to float32 and given a leading
    batch dimension. The C++ caller supplies (H, W, C) uint8 frames, so the
    result is (1, C, H, W).

    Returns:
        A tuple ``(img_curr, im_curr_pad, im_prev_pad)``:
        - the unpadded current frame;
        - the current frame, replicate-padded by 2 rows at the top and bottom;
        - the previous frame, padded the same way.
    """
    # F.pad reads the spec from the last dim backwards: (left, right, top, bottom).
    # Only the height grows, by 2 rows on each side.
    height_pad = [0, 0, 2, 2]

    current = image1.permute(2, 0, 1).float().unsqueeze(0)
    previous = image2.permute(2, 0, 1).float().unsqueeze(0)

    current_padded = F.pad(current, height_pad, mode='replicate')
    previous_padded = F.pad(previous, height_pad, mode='replicate')
    return current, current_padded, previous_padded


def postprocess(image1: torch.Tensor, flow: torch.Tensor):
    """Crop the flow back to the unpadded region and resize both outputs.

    Args:
        image1: batched channel-first image (as returned first by ``preprocess``).
        flow: batched channel-first flow from the traced model. It is assumed to be
            at the padded input resolution (184 x 320); TODO confirm.

    Returns:
        ``(im_curr_out, flow_out)``: both resized to 144 x 256 and permuted to
        channel-last (N, H, W, C) contiguous tensors.
    """
    target_size = [144, 256]

    # Drop the 2 replicate-padded rows top and bottom added before inference.
    cropped_flow = flow[..., 2:182, 0:320]

    # NOTE(review): resizing does not rescale the displacement values themselves.
    # Confirm downstream expects flow in original-resolution pixel units.
    flow_out = fn.resize(cropped_flow, size=target_size).permute(0, 2, 3, 1).contiguous()
    im_curr_out = fn.resize(image1, size=target_size).permute(0, 2, 3, 1).contiguous()

    return im_curr_out, flow_out


class PrePostModule(torch.nn.Module):
    """Wrap a traced flow model with input preprocessing and output postprocessing.

    Meant to be passed to ``torch.jit.script`` so the whole pipeline (two raw
    frames in, resized image and flow out) is saved as a single TorchScript module.
    """

    def __init__(self, traced_model):
        super().__init__()
        # Assigning a Module attribute registers it as a submodule, so it is
        # scripted and saved together with this wrapper.
        self.traced_model = traced_model

    def forward(self, im1, im2):
        im_curr_float, im_curr_pad, im_prev_pad = preprocess(im1, im2)

        # NOTE(review): a trace records fixed control flow, so these presumably
        # have to match the values used when tracing (iters=6, flag=True).
        iters_arg = torch.tensor(6, dtype=torch.int, requires_grad=False)
        flag_arg = torch.tensor(1, dtype=torch.bool, requires_grad=False)
        _, flow = self.traced_model(im_curr_pad, im_prev_pad, iters_arg, flag_arg)

        return postprocess(im_curr_float, flow)


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # RAFT refinement iterations used for tracing.
    # NOTE(review): PrePostModule.forward passes a hard-coded 6; keep the two in sync.
    iters = 6

    args = Namespace(model="raft-small.pth", path='demo-frames', small=True, mixed_precision=False, alternate_corr=False)

    # Load through a DataParallel wrapper and unwrap afterwards. Presumably the
    # checkpoint's keys carry DataParallel's "module." prefix.
    model = torch.nn.DataParallel(RAFT(args))
    # map_location keeps loading working on CPU-only machines even if the
    # checkpoint was saved with CUDA tensors.
    model.load_state_dict(torch.load(args.model, map_location=device))

    model = model.module
    model.to(device)
    model.eval()

    with torch.no_grad():
        # Example inputs for tracing. They are created on `device` so they match
        # the model even when the CPU fallback above was taken.
        im1 = torch.ones((1, 3, 180, 320), device=device)
        im2 = torch.ones((1, 3, 180, 320), device=device)

        # Replicate-pad the height by 2 rows on each side (180 -> 184), mirroring preprocess().
        im1 = F.pad(im1, [0, 0, 2, 2], mode='replicate')
        im2 = F.pad(im2, [0, 0, 2, 2], mode='replicate')

        # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
        # Scripting would be preferable but did not work yet, so trace with constant
        # input params. The pre/post-processing wrapper is scripted afterwards.
        traced_model = torch.jit.trace(model, (im1, im2,
                                               torch.tensor(iters, dtype=torch.int, requires_grad=False),
                                               torch.tensor(1, dtype=torch.bool, requires_grad=False)
                                               )
                                       )

        # Wrap with pre/post-processing, script the wrapper, and save it for LibTorch.
        pre_post_raft = PrePostModule(traced_model)
        scripted_model = torch.jit.script(pre_post_raft)
        torch.jit.save(scripted_model, "raft-small-traced.pth")

A minimal C++ example that reproduces the crash:

// Helpers for silencing compiler warnings raised by LibTorch headers.
// __pragma is an MSVC extension, so on other compilers these helpers expand to
// nothing instead of producing invalid code.
#if defined(_MSC_VER)

#define PUSH_WARNINGS()             \
    __pragma(warning(push))

#define POP_WARNINGS()              \
    __pragma(warning(pop))

#define DISABLE_MSVC_WARNING(x)     \
    __pragma(warning(disable: x))

#else

#define PUSH_WARNINGS()
#define POP_WARNINGS()
#define DISABLE_MSVC_WARNING(x)

#endif

// Save the current warning state, then disable the MSVC warnings that LibTorch
// headers trigger. Pair every use with POP_LIB_LIBTORCH_WARNINGS().
#define PUSH_LIB_LIBTORCH_WARNINGS()                                                                    \
  PUSH_WARNINGS()                                                                              \
  DISABLE_MSVC_WARNING(4251) /* class needs dll-interface */                                   \
  DISABLE_MSVC_WARNING(4100) /* unreferenced formal parameter */                               \
  DISABLE_MSVC_WARNING(4244) /* conversion, possible loss of data */                           \
  DISABLE_MSVC_WARNING(4624) /* destructor implicitly defined as deleted */                    \
  DISABLE_MSVC_WARNING(4267) /* size_t conversion, possible loss of data */                    \
  DISABLE_MSVC_WARNING(4996) /* deprecated declaration */                                      \
  DISABLE_MSVC_WARNING(4805) /* unsafe mix of bool and other type */                           \
  DISABLE_MSVC_WARNING(4275) /* non dll-interface base class */                                \
  DISABLE_MSVC_WARNING(4702) /* unreachable code */                                            \
  DISABLE_MSVC_WARNING(4127) /* conditional expression is constant */                          \
  DISABLE_MSVC_WARNING(4458) /* declaration hides class member */                              \
  DISABLE_MSVC_WARNING(4067) /* unexpected tokens after preprocessor directive */              \
  DISABLE_MSVC_WARNING(4324) /* structure padded due to alignment specifier */

// Restore the warning state saved by PUSH_LIB_LIBTORCH_WARNINGS().
#define POP_LIB_LIBTORCH_WARNINGS() POP_WARNINGS()

PUSH_LIB_LIBTORCH_WARNINGS()
#include <torch/script.h>
#include <torch/torch.h>
POP_LIB_LIBTORCH_WARNINGS()
#include <iostream>
#include "opencv2/core.hpp"

int main()
{
    std::string modelName = "C:/Users/ogcstein/PycharmProjects/AICenteringTraining-zeiss/pipeline/AICenteringTraining/raft_traced/raft-small-traced.pth";
    bool val = torch::cuda::is_available();
    try
    {
        torch::jit::Module m_model;
        torch::Device m_device = torch::Device(torch::kCUDA);

        m_model = torch::jit::load(modelName);
        m_model.to(m_device);
        m_model.eval();

        cv::Mat im_curr_small = cv::Mat::zeros(cv::Size(320, 180), CV_8UC3);
        cv::Mat im_prev_small = cv::Mat::zeros(cv::Size(320, 180), CV_8UC3);

        torch::Tensor flow, im;
        for (auto i = 0; i < 100; i++)
        {

            std::vector<int64_t> dims = { 180, 320, 3 };
            torch::Tensor t1 = torch::from_blob(im_curr_small.data, dims, torch::dtype(torch::kU8)).to(m_device);
            torch::Tensor t2 = torch::from_blob(im_prev_small.data, dims, torch::dtype(torch::kU8)).to(m_device);

            torch::jit::Stack instack;
            instack.push_back(t1);
            instack.push_back(t2);

            auto modelOutput = m_model.forward(instack);

            flow = modelOutput.toTuple()->elements()[1].toTensor();
            im = modelOutput.toTuple()->elements()[0].toTensor();
        }
    }
    catch (const c10::Error e)
    {
        std::string str = std::string("Error: ") + std::string(e.what());
        return 1;
    }
    return 0;
}

1 Like

Same here. Any news on that issue?

@Alok_Wessel You can refer to my post here: ASSERT(initialized()) Debug Error after JIT fusion on Windows · Issue #94908 · pytorch/pytorch · GitHub
It’s definitely a problem connected with the PyTorch JIT fusers. I have not received any updates there yet (and probably never will), but I posted a workaround in the first comment!

1 Like

Your patch seems to work for me as well. Merci!

I ran into the same problem; the key was whether gradients were disabled or not. I also mentioned this in the GitHub issue linked above.