Is it possible to TorchScript-trace a wrapper module that does preprocessing on the CPU and then runs a feed-forward network on the GPU? I'm trying the code below, but I'm not sure whether `.cpu()` and `.cuda()` mean anything inside a traced TorchScript module. The Python code runs, but when I run the traced module in C++ I get an error about `expected device cpu but got device cuda:0`, apparently in the part of the traced code right after `preprocess`. In C++ I do have a line that puts the model on the GPU. I'm not sure whether that is interfering, or more generally whether such lines are redundant when the tracing in Python already has `.cuda()` calls. (Also, I know about `torch::NoGradGuard no_grad;` in C++: can I avoid that if I use `with torch.no_grad()` in the Python tracing code?)
C++ code
model = torch::jit::load("traced_auglag_best.pt");
model.to(at::kCUDA);
model.eval(); //not sure if necessary, was .eval() in the tracing code
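One detail that may matter for the `model.to(at::kCUDA)` line: a module-wide `.to()` acts on every registered parameter and buffer, not only on the inner network. The Python sketch below is a minimal illustration, not the actual wrapper. `StatsAsParams` and the tensor shapes are made up, and a dtype cast stands in for the device move so that it runs without a GPU.

```python
import torch

class StatsAsParams(torch.nn.Module):
    """Hypothetical stand-in: statistics stored as nn.Parameter, as in the wrapper."""
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(4, 2)
        self.mean_f = torch.nn.Parameter(torch.zeros(4, dtype=torch.float64))
        self.std_f = torch.nn.Parameter(torch.ones(4, dtype=torch.float64))

wrapper = StatsAsParams()

# The statistics are listed next to the network weights.
param_names = sorted(name for name, _ in wrapper.named_parameters())

# A module-wide .to() converts all of them; a device move behaves the same way.
dtype_before = wrapper.mean_f.dtype
wrapper.to(torch.float32)
dtype_after = wrapper.mean_f.dtype

print(param_names, dtype_before, dtype_after)
```

If the same holds for the loaded module in C++, the normalization statistics would end up on the GPU while the input tensor is still on the CPU, which would be consistent with the device error described above.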
Python tracing code
import h5py
import torch
import torch.nn.functional as F

class Wrapper(torch.nn.Module):
    #mean_f: Final[float]
    #std_f: Final[float]
    #mean_fdf: Final[float]
    #std_fdf: Final[float]
    def __init__(self,model,file_stats):
        super(Wrapper, self).__init__()
        self.file_stats = file_stats
        self.read_stats()
        self.model = model
        self.model = self.model.cuda()
        self.model.eval()

    def read_stats(self):
        #read normalization statistics and register them as parameters
        fh = h5py.File(self.file_stats,'r')
        attrs = ['mean_f','std_f','mean_fdf','std_fdf']
        for attr in attrs:
            setattr(self,attr,torch.nn.Parameter(torch.from_numpy(fh[attr][...])))
        fh.close()

    def normalize_f(self,f):
        #normalize and convert to float for input to model
        return ((f - self.mean_f[None,...])/self.std_f[None,...]).float()

    def preprocess(self,f):
        #remove negative inds
        f[f<0] = torch.tensor(0.0).double()
        #if adiabatic electron, add dimension for electrons
        #if torch.tensor(f.size())[0]==torch.tensor(1): #(assume for now adiabatic electrons, otherwise trace issue)
        f = F.pad(f,(0,0,0,0,0,0,1,0))
        #switch order for pytorch model to [Ngrid,Nsp,Nmu,Nvpara]
        f = f.permute(2,0,1,3)
        #pad mu direction (so 32,32)
        f = F.pad(f,(0,1,0,0),mode='replicate')
        return f

    def postprocess(self,fdfnorm,f,isp=1):
        #unnormalize
        df = fdfnorm*self.std_fdf[None,[isp],...] + self.mean_fdf[None,[isp],...] - f[:,[isp],...]
        #remove extra vpara dimension
        df = df[:,:,:,:-1]
        #switch order back to XGC order of [Nsp,Nmu,Ngrid,Nvpara] and convert to double
        return df.permute(1,2,0,3).double()

    def forward(self,f):
        with torch.no_grad():
            fpre = self.preprocess(f)
            fnorm = self.normalize_f(fpre)
            out = self.model(fnorm.cuda()).cpu()
            return self.postprocess(out,fpre)
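A possible variant, sketched under assumptions rather than tested against the real model: if the statistics are kept as plain tensors instead of `torch.nn.Parameter`, they are not registered with the module, and tracing should record them as constants in the graph. `PlainStatsWrapper`, the `Linear` stand-in network, and the shapes are hypothetical, and everything stays on the CPU so the snippet runs anywhere.

```python
import io
import torch

class PlainStatsWrapper(torch.nn.Module):
    """Hypothetical variant: statistics held as plain tensors, not parameters."""
    def __init__(self, model, mean_f, std_f):
        super().__init__()
        self.model = model.eval()
        # Plain tensor attributes do not show up in named_parameters().
        self.mean_f = mean_f
        self.std_f = std_f

    def forward(self, f):
        fnorm = ((f - self.mean_f[None, ...]) / self.std_f[None, ...]).float()
        return self.model(fnorm).double()

wrapper = PlainStatsWrapper(
    torch.nn.Linear(4, 2),  # stand-in for the real network
    torch.zeros(4, dtype=torch.float64),
    torch.ones(4, dtype=torch.float64),
)
example = torch.rand(3, 4, dtype=torch.float64)

with torch.no_grad():
    traced = torch.jit.trace(wrapper, example)

# Round-trip through save/load, as the C++ side would load the file.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

names = [name for name, _ in loaded.named_parameters()]
with torch.no_grad():
    same = torch.allclose(loaded(example), wrapper(example))

print(names, same)
```

Whether this helps with the C++ error depends on where the device mismatch actually arises, which the snippet does not establish.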