My simple code for testing the network fails with a CUDA out-of-memory (OOM) error.
def test(model, test_dir):
    """Run *model* on every image file in *test_dir* and save the predicted masks.

    Each image is resized to 640x480, converted to a tensor, run through the
    model on the GPU, reduced to per-pixel class indices with ``argmax`` over
    dim 1, and saved as a PNG in ``<test_dir>/result/`` under the same base name.

    Args:
        model: torch module; moved to CUDA and switched to eval mode here.
        test_dir: directory of input images. Entries that are not regular files
            (including the ``result`` output directory created here) are skipped.
    """
    model.cuda()
    model.eval()
    result_dir = os.path.join(test_dir, 'result')
    os.makedirs(result_dir, exist_ok=True)
    # Disable autograd once for the whole loop so no graph/activations are kept.
    with torch.no_grad():
        for file in os.listdir(test_dir):
            path = os.path.join(test_dir, file)
            # os.listdir also yields the 'result' directory created above (and
            # any other sub-directory); Image.open would fail on those.
            if not os.path.isfile(path):
                continue
            with Image.open(path) as src:
                # Image.resize returns a NEW image. The original code discarded
                # it, so full-resolution images were sent to the GPU.
                img = src.resize((640, 480), Image.BILINEAR)
            x = to_tensor(img).unsqueeze(0).cuda()
            logits = model(x)
            # softmax is monotonic along dim 1, so argmax of the raw outputs
            # yields the same class indices without a second full-size tensor.
            y = torch.argmax(logits, 1)
            # NOTE(review): to_pil_image scales float tensors by 255 before the
            # 8-bit conversion, so class indices > 1 overflow. For multi-class
            # output consider passing y.byte() instead. Confirm the intended
            # mask encoding.
            mask = to_pil_image(y.cpu().float())
            # os.path.splitext replaces str.strip('jpg'), which removed any
            # of the characters 'j', 'p', 'g' from both ends of the name.
            mask.save(os.path.join(result_dir, os.path.splitext(file)[0] + '.png'))
            # Release GPU tensors before the next image is processed.
            del x, logits, y
# Build the full model and restore its weights from the epoch-450 checkpoint.
model = Model()
# Positional arguments must precede keyword arguments: the original
# `load_checkpoint(epoch=450, [('model', model)])` is a SyntaxError.
# NOTE(review): assumes the (name, module) list binds to load_checkpoint's
# first non-`epoch` positional parameter. Confirm against its signature.
load_checkpoint([('model', model)], epoch=450)
# Extract only the segmentation sub-network's weights and load them into a
# standalone segmentation model.
seg = model.segmentation.state_dict()
# NOTE(review): "Segmentaion" looks like a misspelling of "Segmentation". It is
# left unchanged because it must match the class name defined elsewhere. Verify.
model = Segmentaion()
model.load_weights(seg)
# state_dict() returns references to the full model's tensors. Drop this name
# so it does not keep them alive while inference runs.
del seg
test(model, './test/origin')