I saw these kind of topics a lot in forum but generally batch dimension is missing in their tensors. But in my case channel dimension is missing and I don’t know how to just change my model to meet its requirements. Here is my code
class DAE(nn.Module):
def __init__(self):
super(DAE, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1,8,3,1,1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(8, 64, 3, 1, 1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2) )
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
self.down = nn.MaxPool2d(kernel_size=2 , stride=2)
self.deconv1 = nn.Sequential(
nn.Conv2d(64, 8, 3, 1, 1),
nn.ReLU(),
nn.UpsamplingBilinear2d(scale_factor=2)
)
self.deconv2 = nn.Sequential(
nn.Conv2d(8, 1, 3, 1, 1),
nn.ReLU(),
nn.UpsamplingBilinear2d(scale_factor=2)
)
self.linear1 = nn.Linear(64*7*7 , 64 )
self.linear2 = nn.Linear( 64 ,64*7*7 )
def forward(self , x ):
x = self.conv2(self.conv1(x))
# x = self.res5(self.res4(self.res3(self.res2(self.res1(x)))))
x = x.view(x.shape[0] ,64*7*7 )
x = self.linear2(self.linear1(x))
x = x.view(x.shape[0] , 64 ,7 ,7)
x = self.deconv2(self.deconv1(x))
return x
f, axes= plt.subplots(6, 3, figsize = (5, 10))
axes[0,0].set_title("Original Image")
axes[0,1].set_title("Noisy Image")
axes[0,2].set_title("Cleaned Image")
for idx, (noisy, clean, label) in enumerate(test_dataloader):
if idx > 5:
break
# denoising with DAE
noisy = noisy.view(noisy.size(0),-1).type(torch.FloatTensor)
noisy = noisy.to(device)
output = model(noisy)
# fix size
output = output.view(1, 28, 28)
output = output.permute(1, 2, 0).squeeze(2)
output = output.detach().cpu().numpy()
noisy = noisy.view(1, 28, 28)
noisy = noisy.permute(1, 2, 0).squeeze(2)
noisy = noisy.detach().cpu().numpy()
clean = clean.view(1, 28, 28)
clean = clean.permute(1, 2, 0).squeeze(2)
clean = clean.detach().cpu().numpy()
# plot
axes[idx, 0].imshow(clean, cmap="gray")
axes[idx, 1].imshow(noisy, cmap="gray")
axes[idx, 2].imshow(output, cmap="gray")
axes[idx, 0].set(xticks=[], yticks=[])
axes[idx, 1].set(xticks=[], yticks=[])
axes[idx, 2].set(xticks=[], yticks=[])
RuntimeError Traceback (most recent call last)
in ()
13 noisy = noisy.view(noisy.size(0),-1).type(torch.FloatTensor)
14 noisy = noisy.to(device)
—> 15 output = model(noisy)
16
17 # fix size
6 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
–> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
in forward(self, x)
32
33 def forward(self , x ):
—> 34 x = self.conv2(self.conv1(x))
35 # x = self.res5(self.res4(self.res3(self.res2(self.res1(x)))))
36 x = x.view(x.shape[0] ,6477 )
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
–> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py in forward(self, input)
115 def forward(self, input):
116 for module in self:
–> 117 input = module(input)
118 return input
119
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
–> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
417
418 def forward(self, input: Tensor) -> Tensor:
–> 419 return self._conv_forward(input, self.weight)
420
421 class Conv3d(_ConvNd):
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
414 _pair(0), self.dilation, self.groups)
415 return F.conv2d(input, weight, self.bias, self.stride,
–> 416 self.padding, self.dilation, self.groups)
417
418 def forward(self, input: Tensor) -> Tensor:
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [8, 1, 3, 3], but got 2-dimensional input of size [1, 784] instead