I train a model on two GPUs and then save it like this:
net = nn.DataParallel(net)
.....
torch.save(net, save_path)
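Because torch.save(net, save_path) pickles the nn.DataParallel wrapper itself, the saved object still carries the wrapper and the device_ids it was built with. A common alternative is to save only the wrapped module's state_dict. The sketch below is an assumption, not the poster's code: a tiny nn.Linear stands in for the real network, and the checkpoint path is hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the real network being trained
net = nn.DataParallel(nn.Linear(4, 2))
# ... training loop ...

# Hypothetical checkpoint path
save_path = os.path.join(tempfile.gettempdir(), "model_state.pt")

# Save the parameters of the wrapped module rather than the DataParallel
# wrapper, so the checkpoint does not carry the wrapper's device_ids
torch.save(net.module.state_dict(), save_path)

# At inference time: rebuild the plain network and load the weights,
# mapping them to the CPU first (move the model with .cuda() afterwards if needed)
model = nn.Linear(4, 2)
model.load_state_dict(torch.load(save_path, map_location="cpu"))
model.eval()
```

With this layout the inference script no longer depends on how many GPUs were visible during training.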
Then I load the model and run inference on a single GPU:
CUDA_VISIBLE_DEVICES=0 python infer.py
where infer.py looks like this:
import cv2
import torch

im = cv2.imread('./cropped.jpg')
im = cv2.resize(im, (224, 224)).transpose(2, 0, 1)
# Batch of three copies of the image: shape (3, 3, 224, 224)
im = torch.tensor([im, im, im], dtype=torch.float32)
model_path = './res/model.pytorch'
model = torch.load(model_path)
model.eval()
out = model(im).detach().cpu().numpy()
I get the following error:
    test()
  File "infer.py", line 25, in test
    out = model(im).detach().cpu().numpy()
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 110, in forward
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 121, in scatter
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 36, in scatter_kwargs
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 29, in scatter
    return scatter_map(inputs)
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 16, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/nn/parallel/scatter_gather.py", line 14, in scatter_map
    return Scatter.apply(target_gpus, None, dim, obj)
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/nn/parallel/_functions.py", line 73, in forward
    streams = [_get_stream(device) for device in ctx.target_gpus]
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/nn/parallel/_functions.py", line 100, in _get_stream
    if _streams[device] is None:
IndexError: list index out of range
What is wrong with this code?
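The traceback shows that the loaded object is still an nn.DataParallel wrapper: its forward calls scatter with the device_ids recorded at training time, which can point at more GPUs than are visible under CUDA_VISIBLE_DEVICES=0. One possible workaround, sketched here as an assumption rather than a confirmed fix, is to unwrap the inner network through its .module attribute before running inference. The helper name unwrap_data_parallel and the nn.Linear stand-in are hypothetical.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def unwrap_data_parallel(model):
    """Return the network inside DataParallel/DistributedDataParallel, or model unchanged."""
    # Both wrappers keep the original network in the `.module` attribute
    if isinstance(model, (nn.DataParallel, DistributedDataParallel)):
        return model.module
    return model


# Usage sketch with a hypothetical stand-in network
wrapped = nn.DataParallel(nn.Linear(4, 2))
plain = unwrap_data_parallel(wrapped)
print(type(plain).__name__)  # Linear
```

In infer.py this would become something like model = unwrap_data_parallel(torch.load(model_path)). The unwrapped module no longer scatters CPU inputs onto the GPU, so the input would need to be moved to the model's device explicitly (for example im.cuda()). Recent PyTorch releases default torch.load to weights_only=True, so unpickling a whole module there may additionally require weights_only=False.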