I am working on the following code:
The code is written for RGB images (3 channels), but the dataset I'm working with has only a single channel. I have changed lines 108 and 109 of the dataset.py file to process single-channel images.
But I got the following error:
Traceback (most recent call last):
File "train.py", line 68, in <module>
main(cfg)
File "train.py", line 63, in main
solver.fit()
File "F:\CFSRCNN\cfsrcnn_x2\solver.py", line 89, in fit
sr = refiner(lr, scale)
File "C:\Users\anaconda3\envs\CFSRCNN\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Users\anaconda3\envs\CFSRCNN\lib\site-packages\torch\nn\parallel\data_parallel.py", line 159, in forward
return self.module(*inputs[0], **kwargs[0])
File "C:\Users\anaconda3\envs\CFSRCNN\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "F:\CFSRCNN\cfsrcnn_x2\model\cfsrcnn.py", line 74, in forward
x0 = self.sub_mean(x)
File "C:\Users\anaconda3\envs\CFSRCNN\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "F:\CFSRCNN\cfsrcnn_x2\model\ops.py", line 29, in forward
x = self.shifter(x)
File "C:\Users\anaconda3\envs\CFSRCNN\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Users\anaconda3\envs\CFSRCNN\lib\site-packages\torch\nn\modules\conv.py", line 423, in forward
return self._conv_forward(input, self.weight)
File "C:\Users\anaconda3\envs\CFSRCNN\lib\site-packages\torch\nn\modules\conv.py", line 420, in _conv_forward
self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [3, 3, 1, 1], expected input[4, 1, 64, 64] to have 3 channels, but got 1 channels instead
I need to change the model/cfsrcnn.py file so that it can be fed single-channel input. How can I change the MeanShift class to process a single channel? Please help me change this model code so that I can feed it single-channel images. The MeanShift class is given below:
class MeanShift(nn.Module):
    """Frozen 1x1 convolution that adds or subtracts a per-channel mean.

    The number of channels is taken from ``mean_rgb``, so the same layer
    works for RGB input (three means) and for single-channel input (one
    mean) without further changes.

    Args:
        mean_rgb: Per-channel mean values, one entry per image channel,
            e.g. ``(r, g, b)`` for RGB or ``(gray,)`` for grayscale. A bare
            number is also accepted and treated as a single channel.
        sub: If true the means are subtracted from the input; otherwise
            they are added.

    Raises:
        ValueError: If ``mean_rgb`` is an empty sequence.
    """

    def __init__(self, mean_rgb, sub):
        super(MeanShift, self).__init__()

        # Allow a bare scalar as shorthand for a one-channel mean.
        if isinstance(mean_rgb, (int, float)):
            mean_rgb = (mean_rgb,)

        sign = -1 if sub else 1
        shifts = [mean * sign for mean in mean_rgb]
        num_channels = len(shifts)
        if num_channels == 0:
            raise ValueError("mean_rgb must contain at least one channel mean")

        # Conv2d positional arguments: in_channels, out_channels,
        # kernel_size=1, stride=1, padding=0.
        self.shifter = nn.Conv2d(num_channels, num_channels, 1, 1, 0)

        # Freeze the mean shift layer: it is a fixed transform, not a
        # learnable one.
        for params in self.shifter.parameters():
            params.requires_grad = False

        with torch.no_grad():
            # Identity weights pass every channel through unchanged; the
            # bias then carries the signed per-channel shift.
            identity = torch.eye(num_channels).view(
                num_channels, num_channels, 1, 1
            )
            self.shifter.weight.copy_(identity)
            self.shifter.bias.copy_(torch.Tensor(shifts))

    def forward(self, x):
        """Apply the mean shift to ``x`` and return the shifted tensor.

        The channel dimension of ``x`` must match the number of means the
        layer was built with.
        """
        return self.shifter(x)