Hi there,
I am new to PyTorch and I am using a pretrained model in my Django app, which restores broken images. I adapted the code from this repository.
My model class looks like this:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision as tv
from io import BytesIO
from PIL import Image as PILImage

class Hourglass(nn.Module):
    def __init__(self):
        super(Hourglass, self).__init__()
        self.leaky_relu = nn.LeakyReLU()
        # Encoder: six stride-2 convolutions, each halving the spatial size
        self.d_conv_1 = nn.Conv2d(2, 8, 5, stride=2, padding=2)
        self.d_bn_1 = nn.BatchNorm2d(8)
        self.d_conv_2 = nn.Conv2d(8, 16, 5, stride=2, padding=2)
        self.d_bn_2 = nn.BatchNorm2d(16)
        self.d_conv_3 = nn.Conv2d(16, 32, 5, stride=2, padding=2)
        self.d_bn_3 = nn.BatchNorm2d(32)
        self.s_conv_3 = nn.Conv2d(32, 4, 5, stride=1, padding=2)   # skip connection
        self.d_conv_4 = nn.Conv2d(32, 64, 5, stride=2, padding=2)
        self.d_bn_4 = nn.BatchNorm2d(64)
        self.s_conv_4 = nn.Conv2d(64, 4, 5, stride=1, padding=2)   # skip connection
        self.d_conv_5 = nn.Conv2d(64, 128, 5, stride=2, padding=2)
        self.d_bn_5 = nn.BatchNorm2d(128)
        self.s_conv_5 = nn.Conv2d(128, 4, 5, stride=1, padding=2)  # skip connection
        self.d_conv_6 = nn.Conv2d(128, 256, 5, stride=2, padding=2)
        self.d_bn_6 = nn.BatchNorm2d(256)
        # Decoder: stride-2 transposed convolutions, each doubling the spatial size
        self.u_deconv_5 = nn.ConvTranspose2d(256, 64, 4, stride=2, padding=1)
        self.u_bn_5 = nn.BatchNorm2d(64)
        self.u_deconv_4 = nn.ConvTranspose2d(128, 32, 4, stride=2, padding=1)
        self.u_bn_4 = nn.BatchNorm2d(32)
        self.u_deconv_3 = nn.ConvTranspose2d(64, 16, 4, stride=2, padding=1)
        self.u_bn_3 = nn.BatchNorm2d(16)
        self.u_deconv_2 = nn.ConvTranspose2d(32, 8, 4, stride=2, padding=1)
        self.u_bn_2 = nn.BatchNorm2d(8)
        self.u_deconv_1 = nn.ConvTranspose2d(16, 4, 4, stride=2, padding=1)
        self.u_bn_1 = nn.BatchNorm2d(4)
        self.out_deconv = nn.ConvTranspose2d(4, 4, 4, stride=2, padding=1)
        self.out_bn = nn.BatchNorm2d(4)
    def forward(self, noise):
        # Encoder path
        down_1 = self.d_conv_1(noise)
        down_1 = self.d_bn_1(down_1)
        down_1 = self.leaky_relu(down_1)

        down_2 = self.d_conv_2(down_1)
        down_2 = self.d_bn_2(down_2)
        down_2 = self.leaky_relu(down_2)

        down_3 = self.d_conv_3(down_2)
        down_3 = self.d_bn_3(down_3)
        down_3 = self.leaky_relu(down_3)
        skip_3 = self.s_conv_3(down_3)

        down_4 = self.d_conv_4(down_3)
        down_4 = self.d_bn_4(down_4)
        down_4 = self.leaky_relu(down_4)
        skip_4 = self.s_conv_4(down_4)

        down_5 = self.d_conv_5(down_4)
        down_5 = self.d_bn_5(down_5)
        down_5 = self.leaky_relu(down_5)
        skip_5 = self.s_conv_5(down_5)

        down_6 = self.d_conv_6(down_5)
        down_6 = self.d_bn_6(down_6)
        down_6 = self.leaky_relu(down_6)

        # Decoder path; skips are concatenated on the channel dimension
        up_5 = self.u_deconv_5(down_6)
        up_5 = torch.cat([up_5, skip_5], 1)
        up_5 = self.u_bn_5(up_5)
        up_5 = self.leaky_relu(up_5)

        up_4 = self.u_deconv_4(up_5)
        up_4 = torch.cat([up_4, skip_4], 1)
        up_4 = self.u_bn_4(up_4)
        up_4 = self.leaky_relu(up_4)

        up_3 = self.u_deconv_3(up_4)
        up_3 = torch.cat([up_3, skip_3], 1)
        up_3 = self.u_bn_3(up_3)
        up_3 = self.leaky_relu(up_3)

        up_2 = self.u_deconv_2(up_3)
        up_2 = self.u_bn_2(up_2)
        up_2 = self.leaky_relu(up_2)

        up_1 = self.u_deconv_1(up_2)
        up_1 = self.u_bn_1(up_1)
        up_1 = self.leaky_relu(up_1)

        out = self.out_deconv(up_1)
        out = self.out_bn(out)
        out = torch.sigmoid(out)
        return out
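For what it is worth, the module at least constructs fine; this quick check runs without complaints (illustrative only, since constructing the layers does not validate any shapes):

net = Hourglass()
n_params = sum(p.numel() for p in net.parameters())
print('{:,} trainable parameters'.format(n_params))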
I am using it in my view to restore the image, like this:
def restore_image_deep_image_prior(original_image_path):
    lr = 1e-2
    device = 'cpu'
    print('Using {} for computation'.format(device.upper()))
    hg_net = Hourglass()
    hg_net.to(device)
    mse = nn.MSELoss()
    optimizer = optim.Adam(hg_net.parameters(), lr=lr)
    n_iter = 500
    images = []
    losses = []
    to_tensor = tv.transforms.ToTensor()
    # 2-channel coordinate grid as the fixed network input; adjust the size according to your requirement
    z = torch.Tensor(np.mgrid[:542, :347]).unsqueeze(0).to(device) / 512
    x = PILImage.open(original_image_path)
    x = to_tensor(x).unsqueeze(0)
    x, mask = pixel_thanos(x, 0.5)
    mask = mask[:, :3, :, :].to(device).float()  # keep only the first 3 channels of the repeated mask
    x = x.to(device)
    for i in range(n_iter):
        optimizer.zero_grad()
        y = hg_net(z)
        loss = mse(x, y * mask)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        if (i < 1000 and (i + 1) % 4 == 0) or i == 0:
            with torch.no_grad():
                out = x + y * (1 - mask)
                out = out[0].cpu().detach().permute(1, 2, 0) * 255
                out = np.array(out, np.uint8)
                images.append(out)
        if (i + 1) % 20 == 0:
            print('Iteration: {} Loss: {:.07f}'.format(i + 1, losses[-1]))
    restored_image_bytes = image_to_bytes(images[-1])  # encode the last (most restored) frame
    return restored_image_bytes
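For reference, this function is invoked from my Django view roughly like this (simplified; the surrounding request handling is omitted):

restored_image = restore_image_deep_image_prior(original_image.image_file.path)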
def pixel_thanos(img, p=0.5):
    assert p > 0 and p < 1, 'The probability value should lie in (0, 1)'
    mask = torch.rand(1, 3, 542, 347)
    img[mask < p] = 0   # zero out roughly a fraction p of the pixels
    mask = mask > p     # boolean mask of the pixels that were kept
    mask = mask.repeat(1, 3, 1, 1)
    return img, mask
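For context, this is what the shapes look like when I call pixel_thanos on a 3-channel tensor (just an illustrative check; the random test tensor stands in for a real image):

img = torch.rand(1, 3, 542, 347)
img, mask = pixel_thanos(img, 0.5)
print(img.shape)   # torch.Size([1, 3, 542, 347])
print(mask.shape)  # torch.Size([1, 9, 542, 347]) -- the repeat triples the 3 mask channels

This is why I slice mask[:, :3, :, :] back down in the view above.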
def image_to_bytes(image):
    # image is an HxWxC uint8 numpy array, so convert it to a PIL image before encoding
    buffer = BytesIO()
    PILImage.fromarray(image).save(buffer, format="JPEG")
    return buffer.getvalue()
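Called on one of the frames collected above, it would be used like this (a sketch; the zero array is just a placeholder for images[-1]):

frame = np.zeros((542, 347, 3), np.uint8)  # placeholder for a real restored frame
jpeg_bytes = image_to_bytes(frame)
print(len(jpeg_bytes))  # length of the encoded JPEG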
Now, the problem is that I am getting the error listed in the title, which in detail is:
Traceback (most recent call last):
File "C:\Users\khubi\AppData\Local\Programs\Python\Python39\lib\site-packages\django\core\handlers\exception.py", line 55, in inner
response = get_response(request)
File "C:\Users\khubi\AppData\Local\Programs\Python\Python39\lib\site-packages\django\core\handlers\base.py", line 197, in _get_response
response = wrapped_callback(request, *callback_args, **callback_kwargs)
File "C:\Users\khubi\Desktop\Projects\VU\CS619\image restorer\image_restorer\image_app\views.py", line 35, in restore_image
restored_image = restore_image_deep_image_prior(original_image.image_file.path)
File "C:\Users\khubi\Desktop\Projects\VU\CS619\image restorer\image_restorer\image_app\views.py", line 76, in restore_image_deep_image_prior
y = hg_net(z)
File "C:\Users\khubi\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Users\khubi\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Users\khubi\Desktop\Projects\VU\CS619\image restorer\image_restorer\image_app\views.py", line 182, in forward
up_5 = torch.cat([up_5, skip_5], 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 18 but got size 17 for tensor number 1 in the list.
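To see where the 18 and 17 come from, I traced the spatial sizes by hand (a rough sketch, assuming the standard output-size formulas: a Conv2d with kernel 5, stride 2, padding 2 maps n to ceil(n / 2), and a ConvTranspose2d with kernel 4, stride 2, padding 1 maps n to 2 * n):

def down(n):  # Conv2d(kernel=5, stride=2, padding=2): floor((n + 4 - 5) / 2) + 1
    return (n - 1) // 2 + 1

h, w = 542, 347
for i in range(1, 7):
    h, w = down(h), down(w)
    print('down_{}: {} x {}'.format(i, h, w))
# down_5 (and therefore skip_5): 17 x 11
# down_6: 9 x 6, which u_deconv_5 doubles to 18 x 12

So up_5 comes out as 18 x 12 while skip_5 is 17 x 11, which matches the error message, but I do not know the right way to make them line up for a 542 x 347 input.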
I know that I had to set the number of channels, and I set it to 4 since the image I am using for restoration has 4 channels, but I am still getting the error. I also asked ChatGPT for help, but all in vain. Can anybody help me, please?