I am trying to save a model after training by compiling it with TorchScript (`torch.jit.script`), but it throws this error:
NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 147
def forward(self, *inputs, **kwargs):
~~~~~~~ <--- HERE
with torch.autograd.profiler.record_function("DataParallel.forward"):
if not self.device_ids:
Here's the model:
class custom_model(nn.Module):
    """Classification head stacked on the shared ``base_model`` encoder.

    Forward pipeline: ``base_model`` -> ``MeanMaxPool`` -> ``Linear`` ->
    ``Dropout(0.1)`` -> ``Linear`` with ``n_classes`` output units. The final
    layer's output is returned as-is (no activation is applied here).

    ``base_model``, ``MeanMaxPool``, ``hidden_size``, ``inter_size`` and
    ``n_classes`` are read from the enclosing module scope.
    """

    def __init__(self):
        super().__init__()
        # NOTE: keep this construction order. Submodule creation order
        # drives random weight initialisation and state_dict key order.
        self.base = base_model
        self.mean_max_pool = MeanMaxPool(hidden_size)
        self.dropout = nn.Dropout(0.1)
        # The first Linear expects hidden_size * 3 input features, i.e. the
        # width this code assumes MeanMaxPool produces -- confirm against
        # MeanMaxPool's implementation.
        self.fc = nn.Linear(hidden_size * 3, inter_size[0])
        self.out = nn.Linear(inter_size[0], n_classes)

    def forward(self, ids, mask):
        """Run the encoder, pool with ``mask``, then project to class scores.

        Args:
            ids: first input forwarded to ``base_model`` (presumably token
                ids -- confirm).
            mask: forwarded both to ``base_model`` and to the pooling layer
                (presumably an attention mask -- confirm).

        Returns:
            Output of the final ``nn.Linear`` layer (``n_classes`` units).
        """
        # Only element 0 of the encoder's output is used.
        encoded = self.base(ids, mask)[0]
        pooled = self.mean_max_pool(encoded, mask)
        projected = self.dropout(self.fc(pooled))
        return self.out(projected)
# Build the model, train it across two GPUs, then export it with TorchScript.
model = custom_model()
model.to(device)
model = nn.DataParallel(model, device_ids=[0, 1])
train_function(model)

# nn.DataParallel.forward is declared as forward(self, *inputs, **kwargs).
# TorchScript cannot compile variadic/keyword-only signatures, so scripting
# the wrapper raises NotSupportedError. Script the wrapped module instead.
# The wrapper's `.module` is the module whose parameters training updated,
# so the exported weights are the trained ones.
wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
model_to_export = model.module if isinstance(model, wrappers) else model
model_scripted = torch.jit.script(model_to_export)
model_scripted.save('model_scripted.pt')