Mixing GPU/CPU code in TorchScript traced model

Is it possible to trace a wrapper module with TorchScript when the wrapper does preprocessing on the CPU and then runs a feed-forward network on the GPU? I'm trying the code below, but I'm not sure whether .cpu() and .cuda() mean anything inside a traced module. The tracing itself runs, but when I then run the module in C++ I get an error along the lines of "expected device cpu but got device cuda:0", coming from the traced code right after preprocess. In C++ I do have lines that put the model on the GPU; I'm not sure whether those interfere, and more generally whether they are redundant if the Python tracing code already has .cuda() calls. (Also, I know about torch::NoGradGuard no_grad; in C++ -- can I avoid it by using with torch.no_grad() in the Python tracing code?)
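
To make the question concrete, here is a toy sketch (not my real code; the module and shapes are made up) of the CPU/GPU round-trip pattern I mean, and how I've been trying to check what the trace records:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        #same CPU -> GPU -> CPU round trip as in my forward below
        return (x.cuda() * 2).cpu()

toy = torch.jit.trace(Toy(), torch.ones(3))
#I assume the device moves show up as aten::to nodes with the device
#hard-coded at trace time, but I'm not sure how that interacts with
#calling .to(at::kCUDA) on the loaded module in C++
print(toy.graph)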

C++ code

        model = torch::jit::load("traced_auglag_best.pt");
        model.to(at::kCUDA);
        model.eval(); //not sure if this is necessary, since .eval() was already called in the tracing code

Python tracing code

class Wrapper(torch.nn.Module):

    #mean_f: Final[float]
    #std_f: Final[float]
    #mean_fdf: Final[float]
    #std_fdf: Final[float]

    def __init__(self, model, file_stats):
        super(Wrapper, self).__init__()
        self.file_stats = file_stats
        self.read_stats()
        self.model = model
        self.model = self.model.cuda()
        self.model.eval()

    def read_stats(self):
        #stats are registered as nn.Parameters, so they move with .cuda()/.to()
        with h5py.File(self.file_stats, 'r') as fh:
            for attr in ['mean_f', 'std_f', 'mean_fdf', 'std_fdf']:
                setattr(self, attr, torch.nn.Parameter(torch.from_numpy(fh[attr][...])))

    def normalize_f(self, f):
        return ((f - self.mean_f[None,...])/self.std_f[None,...]).float()

    def preprocess(self, f):
        #zero out negative entries
        f[f < 0] = torch.tensor(0.0).double()
        #if adiabatic electron, add dimension for electrons
        #if torch.tensor(f.size())[0]==torch.tensor(1): #(assume for now adiabatic electrons, otherwise trace issue)
        f = F.pad(f,(0,0,0,0,0,0,1,0))
        #switch order for pytorch model to [Ngrid,Nsp,Nmu,Nvpara]
        f = f.permute(2,0,1,3)
        #pad vpara direction so the velocity grid is (32,32)
        f = F.pad(f,(0,1,0,0),mode='replicate')
        #normalization and float conversion happen in normalize_f, called from forward
        return f

    def postprocess(self, fdfnorm, f, isp=1):
        #unnormalize
        df = fdfnorm*self.std_fdf[None,[isp],...] + self.mean_fdf[None,[isp],...] - f[:,[isp],...]
        #remove extra vpara dimension
        df = df[:,:,:,:-1]
        #switch order back to XGC order of [Nsp,Nmu,Ngrid,Nvpara] and convert to double
        return df.permute(1,2,0,3).double()

    def forward(self, f):
        with torch.no_grad():
            fpre = self.preprocess(f)
            fnorm = self.normalize_f(fpre)
            out = self.model(fnorm.cuda()).cpu()
        return self.postprocess(out,fpre)
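
For reference, this is roughly how I trace and save the wrapper (a sketch: MyNet, stats.h5, and the tensor sizes are stand-ins for my real network, stats file, and shapes):

net = MyNet()  #placeholder for the trained feed-forward network
wrapper = Wrapper(net, 'stats.h5')
#input in XGC order [Nsp, Nmu, Ngrid, Nvpara]; sizes here are made up
example = torch.rand(1, 32, 100, 31, dtype=torch.float64)
with torch.no_grad():
    traced = torch.jit.trace(wrapper, example)
traced.save('traced_auglag_best.pt')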