Hello to all.
I have a trained model that only accepts inputs of size 1024 x 1024 (resized from the original image) and produces outputs of the same size.
However, I need it to produce outputs at full resolution, which is no larger than 5000 x 5000. I wonder whether my implementation is wrong or my model is simply too big to produce such an output. (As far as I know that should not be the case, because I’ve seen others produce 1024 x 1024 outputs with models composed of 100+ conv layers.)
Any advice would be appreciated!
Here are my model and my testing procedure:
class model(nn.Module):
    """Two-branch image network; only the constructor is shown in this snippet.

    Builds two ``UNet(3, 3)`` sub-networks, an attention module, three RDB
    blocks and a stack of 3x3 conv+ReLU layers ending in 3 output channels.
    The wiring between them lives in ``forward`` (not shown here).

    Parameters
    ----------
    pretrain : bool
        If True, initialise both UNet sub-networks from the checkpoint files
        given by ``IENetPath`` and ``InvertIENetPath``.
    IENetPath : str
        Path to a ``state_dict`` for ``self.IENetwork``; only read when
        ``pretrain`` is True.
    InvertIENetPath : str
        Path to a ``state_dict`` for ``self.InvertedIENetwork``; only read
        when ``pretrain`` is True.

    NOTE(review): with 64-channel layers presumably running at the full
    input H x W, activation memory scales with image area -- a 5000 x 5000
    image has ~24x the pixels of 1024 x 1024. Confirm against forward().
    """

    def __init__(self, pretrain=False, IENetPath="", InvertIENetPath=""):
        super().__init__()
        # Two UNet(3, 3) branches; presumably one for the image and one for
        # its inverted copy -- confirm against forward().
        self.IENetwork = UNet(3, 3)  # UNet is composed of about 20 conv layers
        self.InvertedIENetwork = UNet(3, 3)
        if pretrain:
            # map_location="cpu": deserialise the checkpoint on the CPU rather
            # than on whichever device it was saved from. This works on
            # machines without that GPU and avoids a temporary second copy of
            # the weights in GPU memory; load_state_dict then copies the
            # values into this module's parameters wherever they live.
            self.IENetwork.load_state_dict(
                torch.load(IENetPath, map_location="cpu")
            )
            self.InvertedIENetwork.load_state_dict(
                torch.load(InvertIENetPath, map_location="cpu")
            )
        self.conv = conv3x3_ReLU(3, 64)
        # AttentionModule is composed of 2 conv layers
        self.attention = AttentionModule(128, 64)
        self.conv2 = conv3x3_ReLU(192, 64)
        # Three RDB(64, 6, 32) blocks; RDB is composed of 6 conv layers.
        self.DRDB1 = RDB(64, 6, 32)
        self.DRDB2 = RDB(64, 6, 32)
        self.DRDB3 = RDB(64, 6, 32)
        self.conv3 = conv3x3_ReLU(192, 64)
        self.conv4 = conv3x3_ReLU(64, 64)
        self.conv5 = conv3x3_ReLU(64, 3)
with torch.no_grad():
    # Inference only: no autograd graph is recorded during the loop below.
    # NOTE(review): training=True in a test run looks suspicious -- confirm
    # whether this flag makes mydataset resize images (e.g. to 1024 x 1024).
    dataset = mydataset(ROOT, TEST_DIR, LABEL_DIR, TEST_IMG_NUM, TEST_TRANSFORM, training=True)
    # Keep the instance under its own name: ``model = model()`` would rebind
    # the class name to the instance, so the class could not be used again.
    net = model().cuda()
    testloader = data.DataLoader(dataset, batch_size=1)
    # Load on the CPU; load_state_dict copies into the (GPU) parameters.
    net.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    net.eval()
    os.makedirs(RES_DIR, exist_ok=True)
    # NOTE(review): at 5000 x 5000 a single full-image forward pass may
    # exceed GPU memory no matter what; running the model on overlapping
    # tiles and stitching the outputs is the usual workaround.
    # ``batch`` (not ``data``) so the loop does not clobber the ``data``
    # module that provides DataLoader above.
    for i, batch in enumerate(testloader):
        Input = batch['train_img']
        Inverted_Input = batch['inverted_img']
        Label = batch['label_img']
        savefolder = os.path.join(RES_DIR, "test" + str(i))
        os.makedirs(savefolder, exist_ok=True)
        # Images saving
        SaveImg(Input, os.path.join(savefolder, "1input.jpg"))
        SaveImg(Inverted_Input, os.path.join(savefolder, "2inverted.jpg"))
        # The weights are on the GPU, so send the inputs there as well (a
        # no-op if forward() already moves them). The CPU tensors above are
        # kept for saving.
        Output, IEmap, IEmap_Inverted = net(
            Input.cuda(),
            Inverted_Input.cuda(),
            Label.cuda(),
            path=savefolder,
            training=TRAIN,
        )
        SaveImg(Output, os.path.join(savefolder, "7finalRes.jpg"))
        SaveImg(Label, os.path.join(savefolder, "8label.jpg"))
        # Release this iteration's results now; otherwise they stay
        # referenced through the next forward pass and raise peak memory.
        del Output, IEmap, IEmap_Inverted
    print("TestEnds")